vit.py

import torch
from torch import nn
from einops.layers.torch import Rearrange
from mhsa import MultiHeadedSelfAttention
import vitconfigs as vcfg
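# Note (assumption, inferred from how these imports are used below):
#   - mhsa.MultiHeadedSelfAttention(embed_dim, head_dim, nheads, dropout) is an nn.Module
#     that maps a token tensor of shape (B, N, embed_dim) to a tensor of the same shape.
#   - vitconfigs.base is a dict-like config providing the keys read in
#     VisionTransformer.__init__ (input_size, patch_size, embed_dim, salayers, nheads,
#     head_dim, mlp_hdim, drop_prob, nclasses).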

class TransformerEncoder(nn.Module):
    '''
    Although torch has an nn.Transformer class, it includes both encoder and decoder layers
    (with cross attention). Since ViT requires only the encoder, we can't use nn.Transformer,
    so we define a new class.
    '''
    def __init__(self, nheads, nlayers, embed_dim, head_dim, mlp_hdim, dropout):
        '''
        nheads: (int) number of heads in each MSA layer
        nlayers: (int) number of MSA layers in the transformer
        embed_dim: (int) dimension of input tokens
        head_dim: (int) dimensionality of each attention head
        mlp_hdim: (int) hidden dimension of the MLP (feed-forward) block
        dropout: (float 0~1) probability of dropping a node
        '''
        super(TransformerEncoder, self).__init__()
        self.nheads = nheads
        self.nlayers = nlayers
        self.embed_dim = embed_dim
        self.head_dim = head_dim
        self.mlp_hdim = mlp_hdim
        self.drop_prob = dropout
        self.salayers, self.fflayers = self.getlayers()

    def getlayers(self):
        samodules = nn.ModuleList()
        ffmodules = nn.ModuleList()
        for _ in range(self.nlayers):
            # pre-norm self-attention block
            sam = nn.Sequential(
                nn.LayerNorm(self.embed_dim),
                MultiHeadedSelfAttention(
                    self.embed_dim,
                    self.head_dim,
                    self.nheads,
                    self.drop_prob
                )
            )
            samodules.append(sam)
            # pre-norm feed-forward (MLP) block
            ffm = nn.Sequential(
                nn.LayerNorm(self.embed_dim),
                nn.Linear(self.embed_dim, self.mlp_hdim),
                nn.GELU(),
                nn.Dropout(self.drop_prob),
                nn.Linear(self.mlp_hdim, self.embed_dim),
                nn.Dropout(self.drop_prob)
            )
            ffmodules.append(ffm)
        return samodules, ffmodules

    def forward(self, x):
        for (sal, ffl) in zip(self.salayers, self.fflayers):
            x = x + sal(x)
            x = x + ffl(x)
        return x
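
# Usage sketch (hypothetical hyperparameters, not taken from vitconfigs): an encoder built as
#   enc = TransformerEncoder(nheads=12, nlayers=12, embed_dim=768, head_dim=64, mlp_hdim=3072, dropout=0.1)
# maps a token tensor of shape (B, N, 768) to a tensor of the same shape, applying nlayers
# pre-norm residual blocks of self-attention, each followed by a residual MLP block.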

class VisionTransformer(nn.Module):
    def __init__(self, cfg):
        super(VisionTransformer, self).__init__()
        input_size = cfg['input_size']
        self.patch_size = cfg['patch_size']
        self.embed_dim = cfg['embed_dim']
        salayers = cfg['salayers']
        nheads = cfg['nheads']
        head_dim = cfg['head_dim']
        mlp_hdim = cfg['mlp_hdim']
        drop_prob = cfg['drop_prob']
        nclasses = cfg['nclasses']
        # number of patch tokens plus one for the class token
        self.num_patches = (input_size[0]//self.patch_size)*(input_size[1]//self.patch_size) + 1
        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h px) (w py) -> b (h w) (px py c)', px=self.patch_size, py=self.patch_size),
            nn.Linear(self.patch_size*self.patch_size*3, self.embed_dim)
        )
        self.dropout_layer = nn.Dropout(drop_prob)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        # Similar to BERT, the cls token is introduced as a learnable parameter at the
        # beginning of the ViT model. This token evolves through self-attention and is
        # finally used to classify the image at the end; the tokens from all the other
        # patches are IGNORED at that point.
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches, self.embed_dim))
        # learnable position embedding
        self.transformer = TransformerEncoder(
            nheads=nheads,
            nlayers=salayers,
            embed_dim=self.embed_dim,
            head_dim=head_dim,
            mlp_hdim=mlp_hdim,
            dropout=drop_prob
        )
        self.prediction_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, nclasses)
        )

    def forward(self, x):
        # x is in NCHW format
        npatches = (x.size(2)//self.patch_size)*(x.size(3)//self.patch_size) + 1
        embed = self.patch_embedding(x)
        x = torch.cat((self.cls_token.repeat(x.size(0), 1, 1), embed), dim=1)
        # repeat the class token for every sample in the batch and concatenate along the
        # patch dimension, so the class token is treated just like any other patch
        if npatches == self.num_patches:
            x += self.positional_embedding
            # this works only if the input image size matches the one specified in the constructor
        else:
            interpolated = nn.functional.interpolate(
                self.positional_embedding[None],  # insert dummy dimension
                (npatches, self.embed_dim),
                mode='bilinear'
            )
            # we request bilinear interpolation, but since the embedding dimension is kept
            # unchanged, only linear interpolation along the token dimension actually happens
            x += interpolated[0]  # remove dummy dimension
        x = self.dropout_layer(x)
        x = self.transformer(x)
        x = x[:, 0, :]
        # use the first (class) token for classification and ignore everything else
        pred = self.prediction_head(x)
        return pred
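
# Illustrative config (ViT-Base-like values shown only as an example; the actual values
# used below come from vitconfigs.base):
#   {
#       'input_size': (224, 224), 'patch_size': 16, 'embed_dim': 768,
#       'salayers': 12, 'nheads': 12, 'head_dim': 64, 'mlp_hdim': 3072,
#       'drop_prob': 0.1, 'nclasses': 1000,
#   }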

if __name__ == '__main__':
    net = VisionTransformer(vcfg.base)
    nparams = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'Created model with {nparams} parameters')
    x = torch.randn(1, 3, 224, 224)
    y = net(x)
    print(y.shape)
    print('Verified Vision Transformer forward pass')
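    # Also exercise the positional-embedding interpolation branch with a different input
    # resolution (assumes the configured patch size divides 256 evenly).
    x2 = torch.randn(1, 3, 256, 256)
    y2 = net(x2)
    print(y2.shape)
    print('Verified forward pass with interpolated position embeddings')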