import tensorflow as tf

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, **kargs):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = kargs['num_heads']
        self.d_model = kargs['d_model']

        # d_model must be evenly divisible across the heads
        assert self.d_model % self.num_heads == 0
        self.depth = self.d_model // self.num_heads

        # projection layers for query, key, and value
        self.wq = tf.keras.layers.Dense(kargs['d_model'])
        self.wk = tf.keras.layers.Dense(kargs['d_model'])
        self.wv = tf.keras.layers.Dense(kargs['d_model'])

        # final output projection
        self.dense = tf.keras.layers.Dense(kargs['d_model'])

    def split_heads(self, x, batch_size):
        # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # compute the scaled dot-product attention and the attention weights
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # (batch_size, seq_len, num_heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        # concatenate the heads back into (batch_size, seq_len, d_model)
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)  # (batch_size, seq_len, d_model)
        return output, attention_weights
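
The layer above calls a scaled_dot_product_attention helper that is not defined in this post. A minimal sketch, assuming the standard formulation softmax(QK^T / sqrt(d_k))V with an optional additive mask (as used in the TensorFlow Transformer tutorial), could look like this:

def scaled_dot_product_attention(q, k, v, mask):
    # attention scores: (..., seq_len_q, seq_len_k)
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    # scale by the square root of the key depth
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # push masked positions toward -inf so softmax assigns them ~0 weight
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
    return output, attention_weights

A quick shape check with random tensors (the num_heads and d_model values here are just illustrative):

mha = MultiHeadAttention(num_heads=8, d_model=512)
x = tf.random.uniform((1, 60, 512))        # (batch_size, seq_len, d_model)
out, attn = mha(v=x, k=x, q=x, mask=None)
print(out.shape)   # (1, 60, 512)
print(attn.shape)  # (1, 8, 60, 60)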