# model.py
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torchvision.models as models
  5. class MultiOutputModel(nn.Module):
  6. def __init__(self, n_color_classes, n_gender_classes, n_article_classes):
  7. super().__init__()
  8. self.base_model = models.mobilenet_v2().features # take the model without classifier
  9. last_channel = models.mobilenet_v2().last_channel # size of the layer before classifier
  10. # the input for the classifier should be two-dimensional, but we will have
  11. # [batch_size, channels, width, height]
  12. # so, let's do the spatial averaging: reduce width and height to 1
  13. self.pool = nn.AdaptiveAvgPool2d((1, 1))
  14. # create separate classifiers for our outputs
  15. self.color = nn.Sequential(
  16. nn.Dropout(p=0.2),
  17. nn.Linear(in_features=last_channel, out_features=n_color_classes)
  18. )
  19. self.gender = nn.Sequential(
  20. nn.Dropout(p=0.2),
  21. nn.Linear(in_features=last_channel, out_features=n_gender_classes)
  22. )
  23. self.article = nn.Sequential(
  24. nn.Dropout(p=0.2),
  25. nn.Linear(in_features=last_channel, out_features=n_article_classes)
  26. )
  27. def forward(self, x):
  28. x = self.base_model(x)
  29. x = self.pool(x)
  30. # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier
  31. x = torch.flatten(x, 1)
  32. return {
  33. 'color': self.color(x),
  34. 'gender': self.gender(x),
  35. 'article': self.article(x)
  36. }
  37. def get_loss(self, net_output, ground_truth):
  38. color_loss = F.cross_entropy(net_output['color'], ground_truth['color_labels'])
  39. gender_loss = F.cross_entropy(net_output['gender'], ground_truth['gender_labels'])
  40. article_loss = F.cross_entropy(net_output['article'], ground_truth['article_labels'])
  41. loss = color_loss + gender_loss + article_loss
  42. return loss, {'color': color_loss, 'gender': gender_loss, 'article': article_loss}