mhsa.py

import torch
from torch import nn
from einops.layers.torch import Rearrange


class MultiHeadedSelfAttention(nn.Module):
    def __init__(self, indim, adim, nheads, drop):
        '''
        indim: (int) dimension of each input token vector
        adim: (int) dimensionality of each attention head
        nheads: (int) number of heads in the MHSA layer
        drop: (float 0~1) probability of dropping a node

        Implements a QKV multi-headed self-attention layer:
        output = softmax(Q*K^T / sqrt(d)) * V
        scale = 1/sqrt(d); here d = adim (the per-head dimension)
        '''
        super(MultiHeadedSelfAttention, self).__init__()
        hdim = adim * nheads
        self.scale = adim ** -0.5  # scale in softmax(Q*K^T*scale)*V, with d = adim
        # one linear projection each for Q, K and V, split into nheads heads
        self.key_lyr = self.get_qkv_layer(indim, hdim, nheads)
        self.query_lyr = self.get_qkv_layer(indim, hdim, nheads)
        self.value_lyr = self.get_qkv_layer(indim, hdim, nheads)
        self.attention_scores = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(drop)
        # merge the heads back together and project to the input dimension
        self.out_layer = nn.Sequential(
            Rearrange('bsize nheads seqlen hdim -> bsize seqlen (nheads hdim)'),
            nn.Linear(hdim, indim),
            nn.Dropout(drop))

    def get_qkv_layer(self, indim, hdim, nheads):
        '''
        returns a query, key, or value layer
        (call this function thrice to get all of the q, k & v layers)
        '''
        layer = nn.Sequential(
            nn.Linear(indim, hdim, bias=False),
            Rearrange('bsize seqlen (nheads hdim) -> bsize nheads seqlen hdim', nheads=nheads))
        return layer

    def forward(self, x):
        # x: (bsize, seqlen, indim)
        query = self.query_lyr(x)
        key = self.key_lyr(x)
        value = self.value_lyr(x)
        # scaled dot-product attention, computed per head
        dotp = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        scores = self.attention_scores(dotp)
        scores = self.dropout(scores)
        weighted = torch.matmul(scores, value)
        out = self.out_layer(weighted)
        return out
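

# Usage sketch (not part of the original file): a quick shape check under
# assumed sizes -- batch of 2 sequences, 16 tokens each, 64-dim embeddings,
# 8 heads of 32 dims. All names and values below are illustrative assumptions.
if __name__ == '__main__':
    mhsa = MultiHeadedSelfAttention(indim=64, adim=32, nheads=8, drop=0.1)
    x = torch.randn(2, 16, 64)   # (bsize, seqlen, indim)
    y = mhsa(x)
    print(y.shape)               # expected: torch.Size([2, 16, 64])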