import tensorflow as tf

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, **kargs):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = kargs['num_heads']
        self.d_model = kargs['d_model']

        # d_model must divide evenly across the heads
        assert self.d_model % self.num_heads == 0

        self.depth = self.d_model // self.num_heads

        # linear projections for queries, keys, and values
        self.wq = tf.keras.layers.Dense(kargs['d_model'])
        self.wk = tf.keras.layers.Dense(kargs['d_model'])
        self.wv = tf.keras.layers.Dense(kargs['d_model'])

        # final output projection
        self.dense = tf.keras.layers.Dense(kargs['d_model'])

    def split_heads(self, x, batch_size):
        # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # compute the scaled dot-product attention and the per-head attention weights
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask
        )

        # (batch_size, num_heads, seq_len, depth) -> (batch_size, seq_len, num_heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        # concatenate the heads back into (batch_size, seq_len, d_model)
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)  # (batch_size, seq_len, d_model)

        return output, attention_weights
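The call method above relies on a scaled_dot_product_attention function defined in an earlier post of this series. A minimal sketch of that helper is shown below, following the standard formulation softmax(QK^T / sqrt(d_k))V; the example hyperparameters (num_heads=8, d_model=512) and the dummy input are illustrative assumptions, not values from the original post.

def scaled_dot_product_attention(q, k, v, mask):
    # attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    matmul_qk = tf.matmul(q, k, transpose_b=True)       # (..., seq_len_q, seq_len_k)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    if mask is not None:
        # masked positions receive a large negative logit so softmax drives them to ~0
        scaled_attention_logits += (mask * -1e9)

    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(attention_weights, v)             # (..., seq_len_q, depth_v)
    return output, attention_weights

# Example usage with assumed hyperparameters:
mha = MultiHeadAttention(num_heads=8, d_model=512)
x = tf.random.uniform((1, 60, 512))                      # (batch_size, seq_len, d_model)
out, attn = mha(v=x, k=x, q=x, mask=None)
print(out.shape, attn.shape)                             # (1, 60, 512) (1, 8, 60, 60)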

๋‹คํ–ˆ๋‹ค