stylenet.py

import torch
import torch.nn as nn
from torch import optim
import time
import os
from PIL import Image
from torchvision import transforms as T
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
from dataset import DataManager
import config as cfg
import sys

# NVIDIA apex (mixed-precision training) is typically unavailable on macOS, so skip it there.
if sys.platform != 'darwin':
    from apex import amp


class StyleNetwork(nn.Module):
    """Encoder-decoder network that learns to apply a single style to arbitrary images."""

    def __init__(self, loadpath=None):
        super(StyleNetwork, self).__init__()
        self.loadpath = loadpath

        # Encoder: four stride-2 convolutions, 3 -> 16 -> 32 -> 64 -> 128 channels.
        self.layer1 = self.get_conv_module(inc=3, outc=16, ksize=9)
        self.layer2 = self.get_conv_module(inc=16, outc=32)
        self.layer3 = self.get_conv_module(inc=32, outc=64)
        self.layer4 = self.get_conv_module(inc=64, outc=128)

        # Skip connections: depthwise-separable blocks applied to encoder
        # features before they are concatenated into the decoder.
        self.connector1 = self.get_depthwise_separable_module(128, 128)
        self.connector2 = self.get_depthwise_separable_module(64, 64)
        self.connector3 = self.get_depthwise_separable_module(32, 32)

        # Decoder: stride-2 transposed convolutions back to full resolution.
        self.layer5 = self.get_deconv_module(256, 64)
        self.layer6 = self.get_deconv_module(128, 32)
        self.layer7 = self.get_deconv_module(64, 16)
        self.layer8 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.activation = nn.Sigmoid()  # keep output pixels in [0, 1]

        if self.loadpath:
            self.load_state_dict(torch.load(self.loadpath, map_location=torch.device('cpu')))

    def get_conv_module(self, inc, outc, ksize=3):
        """Stride-2 convolution -> batch norm -> LeakyReLU; halves the spatial size."""
        padding = (ksize - 1) // 2
        conv = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=ksize, stride=2, padding=padding)
        bn = nn.BatchNorm2d(outc)
        relu = nn.LeakyReLU(0.1)
        return nn.Sequential(conv, bn, relu)

    def get_deconv_module(self, inc, outc, ksize=3):
        """Stride-2 transposed convolution -> batch norm -> LeakyReLU; doubles the spatial size."""
        padding = (ksize - 1) // 2
        tconv = nn.ConvTranspose2d(inc, outc, kernel_size=ksize, stride=2, padding=padding, output_padding=padding)
        bn = nn.BatchNorm2d(outc)
        relu = nn.LeakyReLU(0.1)
        return nn.Sequential(tconv, bn, relu)
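
    # Sanity check (hedged sketch): for ConvTranspose2d,
    # out = (in - 1) * stride - 2 * padding + ksize + output_padding.
    # With ksize=3, padding=1, output_padding=1 and stride=2 this is
    # (in - 1) * 2 - 2 + 3 + 1 = 2 * in, so each deconv module exactly
    # doubles the spatial resolution.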

    def get_depthwise_separable_module(self, inc, outc):
        """
        inc (int): number of input channels
        outc (int): number of output channels
        Implements a depthwise separable convolution (a per-channel 3x3
        convolution followed by a 1x1 pointwise convolution) along with
        batch norm and activation. Intended to be used with inc == outc
        in the current architecture.
        """
        depthwise = nn.Conv2d(inc, inc, kernel_size=3, stride=1, padding=1, groups=inc)
        pointwise = nn.Conv2d(inc, outc, kernel_size=1, stride=1, padding=0, groups=1)
        bn_layer = nn.BatchNorm2d(outc)
        activation = nn.LeakyReLU(0.1)
        return nn.Sequential(depthwise, pointwise, bn_layer, activation)
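
    # Why depthwise-separable here: it is far cheaper than a dense 3x3
    # convolution. A rough count for inc = outc = 128 (ignoring biases and
    # batch-norm parameters):
    #   dense 3x3:     128 * 128 * 3 * 3 = 147,456 weights
    #   depthwise 3x3: 128 * 3 * 3       =   1,152 weights
    #   pointwise 1x1: 128 * 128         =  16,384 weights
    #   separable total                  =  17,536 weights, roughly 8x fewer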

    def forward(self, x):
        # Encoder
        x = self.layer1(x)
        x2 = self.layer2(x)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        # Skip connections
        xs4 = self.connector1(x4)
        xs3 = self.connector2(x3)
        xs2 = self.connector3(x2)
        # Decoder, concatenating skip features along the channel dimension
        c1 = torch.cat([x4, xs4], dim=1)
        x5 = self.layer5(c1)
        c2 = torch.cat([x5, xs3], dim=1)
        x6 = self.layer6(c2)
        c3 = torch.cat([x6, xs2], dim=1)
        x7 = self.layer7(c3)
        out = self.layer8(x7)
        out = self.activation(out)
        return out
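

# Shape sketch (hedged): four stride-2 encoder stages shrink a 256x256 input
# to 16x16, and four stride-2 decoder stages restore it, so for inputs whose
# sides are multiples of 16 the output matches the input resolution:
#   net = StyleNetwork()
#   out = net(torch.randn(1, 3, 256, 256))  # -> torch.Size([1, 3, 256, 256])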


class StyleLoss(nn.Module):
    def __init__(self):
        super(StyleLoss, self).__init__()

    def forward(self, target_features, output_features):
        loss = 0
        for target_f, out_f in zip(target_features, output_features):
            # Target is batch size 1: the single style image.
            t_bs, t_ch, t_h, t_w = target_f.shape
            assert t_bs == 1, 'Network should be trained for only one target image'
            target_f = target_f.reshape(t_ch, t_h * t_w)
            target_gram_matrix = torch.matmul(target_f, target_f.T) / (t_ch * t_h * t_w)  # t_ch x t_ch matrix
            i_bs, i_ch, i_h, i_w = out_f.shape
            assert t_ch == i_ch, 'Bug'
            for img_f in out_f:  # out_f holds features for a batch of images
                img_f = img_f.reshape(i_ch, i_h * i_w)
                img_gram_matrix = torch.matmul(img_f, img_f.T) / (i_ch * i_h * i_w)
                loss += torch.square(target_gram_matrix - img_gram_matrix).mean()
        return loss
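

# The Gram matrix G = F @ F.T / (C*H*W) captures which channels co-activate
# while discarding spatial layout, which is why matching Gram matrices
# transfers texture/style rather than content. A hedged sketch for one
# feature map:
#   f = torch.randn(1, 64, 32, 32)    # (batch, C, H, W)
#   g = f[0].reshape(64, -1)          # C x (H*W)
#   gram = g @ g.T / (64 * 32 * 32)   # C x C, matches the normalization above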


class ContentLoss(nn.Module):
    def __init__(self):
        super(ContentLoss, self).__init__()

    def forward(self, stylized_features, content_features):
        loss = 0
        for sf, cf in zip(stylized_features, content_features):
            a, b, c, d = sf.shape
            # .mean() already averages over all a*b*c*d elements, so the extra
            # division only rescales the loss relative to the other terms.
            loss += (torch.square(sf - cf) / (a * b * c * d)).mean()
        return loss


class TotalVariationLoss(nn.Module):
    """Penalizes differences between neighbouring pixels to suppress high-frequency noise."""

    def __init__(self):
        super(TotalVariationLoss, self).__init__()

    def forward(self, x):
        # Squared differences along the height (vertical) and width (horizontal) axes.
        vertical_loss = torch.pow(x[..., 1:, :] - x[..., :-1, :], 2).sum()
        horizontal_loss = torch.pow(x[..., 1:] - x[..., :-1], 2).sum()
        return (horizontal_loss + vertical_loss) / x.numel()
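

# Worked example (hedged): for a 2x2 plane [[0, 1], [0, 1]] the horizontal
# squared differences sum to 2 and the vertical ones to 0, so the loss is
# (2 + 0) / 4 = 0.5 before the 0.02 weighting applied in the trainer.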


class StyleTrainer(object):
    def __init__(self, student_network, loss_network, style_target_path, data_manager, feature_loss, style_loss, savepath=None):
        self.student_network = student_network
        self.loss_network = loss_network
        assert os.path.exists(style_target_path), 'Style target does not exist'
        # PIL's resize expects (width, height), hence the reversed cfg.SIZE.
        image = Image.open(style_target_path).convert('RGB').resize(cfg.SIZE[::-1])
        # ImageNet normalization, matching the statistics the loss network was trained with.
        preprocess = T.Compose([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        self.style_target = torch.unsqueeze(preprocess(image), 0)
        self.manager = data_manager
        self.feature_loss = feature_loss
        self.style_loss = style_loss
        self.total_variation = TotalVariationLoss()
        self.savepath = savepath
        self.writer = SummaryWriter()
        self.optimizer = optim.Adam(self.student_network.parameters(), lr=cfg.LR)

    def train(self, epochs=None, eval_interval=None, style_loss_weight=1.0):
        epochs = epochs if epochs else cfg.EPOCHS
        eval_interval = eval_interval if eval_interval else cfg.EVAL_INTERVAL
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        train_loader, valid_loader, *_ = self.manager.dataloaders  # ignore test loader if any

        self.student_network.to(device).train()
        self.loss_network.to(device)
        self.loss_network.eval()

        # Mixed precision only where apex was imported (see top of file) and a GPU exists.
        use_amp = sys.platform != 'darwin' and torch.cuda.is_available()
        if use_amp:
            self.student_network, self.optimizer = amp.initialize(self.student_network, self.optimizer,
                                                                  opt_level='O2', enabled=True)

        self.style_target = self.style_target.to(device)
        style_target_features = resnet_forward(self.loss_network, self.style_target)  # fixed during training

        step = 0
        for epoch in range(epochs):
            estart = time.time()
            for x in train_loader:
                self.optimizer.zero_grad()
                x = x.to(device)
                stylized_image = self.student_network(x)
                content_features = resnet_forward(self.loss_network, x)
                stylized_features = resnet_forward(self.loss_network, stylized_image)

                feature_loss = self.feature_loss(stylized_features, content_features)
                # The style loss compares the stylized output (not the content
                # input, which carries no gradient) against the style target.
                style_loss = self.style_loss(style_target_features, stylized_features)
                tvloss = self.total_variation(stylized_image)
                loss = 1000 * feature_loss + style_loss_weight * style_loss + 0.02 * tvloss

                self.writer.add_scalar('Feature loss', feature_loss.item(), step)
                self.writer.add_scalar('Style loss', style_loss.item(), step)
                self.writer.add_scalar('Total Variation Loss', tvloss.item(), step)

                if use_amp:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                self.optimizer.step()
                step += 1

                if step % eval_interval == 0:
                    # Log one batch of stylized validation images to TensorBoard.
                    self.student_network.eval()
                    with torch.no_grad():
                        for imgs in valid_loader:
                            imgs = imgs.to(device)
                            stylized = self.student_network(imgs)
                            self.writer.add_images('Stylized Examples', stylized, step)
                            break  # just one batch is enough
                    self.student_network.train()

            self.save(epoch)
            eend = time.time()
            print('Time taken for last epoch = {:.3f}'.format(eend - estart))

    def save(self, epoch):
        if self.savepath:
            path = self.savepath.format(epoch)
            torch.save(self.student_network.state_dict(), path)
            print(f'Saved model to {path}')


def resnet_forward(net, x):
    """Run a torchvision ResNet layer by layer and collect the four residual-stage outputs."""
    layers_used = ['layer1', 'layer2', 'layer3', 'layer4']
    output = []
    for name, module in net._modules.items():
        if name == 'fc':
            continue  # don't run the fc layer, since _modules does not include the flatten step
        x = module(x)
        if name in layers_used:
            output.append(x)
    return output
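

# Note (hedged, depends on the torchvision version): for resnet18, _modules
# iterates conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc. The fc
# layer is skipped above, so the returned list holds the feature maps after
# each of the four residual stages.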


if __name__ == "__main__":
    net = StyleNetwork()
    manager = DataManager(cfg.IMGPATH_FILE, None, cfg.SIZE)  # DataManager without soft targets
    styleloss = StyleLoss()
    contentloss = ContentLoss()

    loss_network = models.resnet18()
    loss_network.load_state_dict(torch.load(cfg.LOSS_NET_PATH)['model'])
    for p in loss_network.parameters():
        p.requires_grad = False  # freeze the loss network

    trainer = StyleTrainer(net, loss_network, cfg.STYLE_TARGET, manager, contentloss, styleloss, './style_{}.pth')
    trainer.train()