train_classifier.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. #!/usr/bin/env python3
  2. import time
  3. import random
  4. import numpy as np
  5. import gym
  6. from gym_minigrid.register import env_list
  7. from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
  8. import babyai
  9. import torch
  10. import torch.nn as nn
  11. import torch.optim as optim
  12. import torch.nn.functional as F
  13. from torch.autograd import Variable
  14. import torchvision
  15. import numpy as np
  16. import cv2
  17. import PIL
  18. ##############################################################################
  19. def make_var(arr):
  20. arr = np.ascontiguousarray(arr)
  21. #arr = torch.from_numpy(arr).float()
  22. arr = torch.from_numpy(arr)
  23. arr = Variable(arr)
  24. if torch.cuda.is_available():
  25. arr = arr.cuda()
  26. return arr
  27. def init_weights(m):
  28. classname = m.__class__.__name__
  29. if classname.startswith('Conv'):
  30. nn.init.orthogonal_(m.weight.data)
  31. m.bias.data.fill_(0)
  32. elif classname.find('Linear') != -1:
  33. nn.init.xavier_uniform_(m.weight)
  34. m.bias.data.fill_(0)
  35. elif classname.find('BatchNorm') != -1:
  36. m.weight.data.normal_(1.0, 0.02)
  37. m.bias.data.fill_(0)
  38. class ImageBOWEmbedding(nn.Module):
  39. def __init__(self, num_embeddings, embedding_dim, padding_idx=None, reduce_fn=torch.mean):
  40. super(ImageBOWEmbedding, self).__init__()
  41. self.num_embeddings = num_embeddings
  42. self.embedding_dim = embedding_dim
  43. self.padding_idx = padding_idx
  44. self.reduce_fn = reduce_fn
  45. self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
  46. def forward(self, inputs):
  47. embeddings = self.embedding(inputs.long())
  48. embeddings = self.reduce_fn(embeddings, dim=1)
  49. embeddings = torch.transpose(embeddings, 1, 3)
  50. embeddings = torch.transpose(embeddings, 2, 3)
  51. return embeddings
  52. class Flatten(nn.Module):
  53. """
  54. Flatten layer, to flatten convolutional layer output
  55. """
  56. def forward(self, input):
  57. return input.view(input.size(0), -1)
  58. class Model(nn.Module):
  59. def __init__(self):
  60. super().__init__()
  61. self.layers = nn.Sequential(
  62. #ImageBOWEmbedding(765, embedding_dim=16, padding_idx=0, reduce_fn=torch.mean),
  63. #nn.Conv2d(in_channels=16, out_channels=64, kernel_size=1),
  64. #nn.LeakyReLU(),
  65. nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1),
  66. nn.LeakyReLU(),
  67. nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
  68. nn.LeakyReLU(),
  69. nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1),
  70. nn.LeakyReLU(),
  71. nn.Conv2d(in_channels=2, out_channels=2, kernel_size=7),
  72. nn.LeakyReLU(),
  73. Flatten(),
  74. # Two output heads, one for each class
  75. nn.Linear(2, 2)
  76. )
  77. self.apply(init_weights)
  78. def forward(self, obs):
  79. obs = obs / 16
  80. out = self.layers(obs)
  81. return out
  82. def present_prob(self, obs):
  83. obs = make_var(obs).unsqueeze(0)
  84. logits = self(obs)
  85. probs = F.softmax(logits, dim=-1)
  86. probs = probs.detach().cpu().squeeze().numpy()
  87. return probs[1]
  88. env = gym.make('BabyAI-GoToRedBall-v0')
  89. def sample_batch(batch_size=128):
  90. imgs = []
  91. labels = []
  92. for i in range(batch_size):
  93. obs = env.reset()['image']
  94. ball_visible = ('red', 'ball') in Grid.decode(obs)
  95. obs = obs.transpose([2, 0, 1])
  96. imgs.append(np.copy(obs))
  97. labels.append(ball_visible)
  98. imgs = np.stack(imgs).astype(np.float32)
  99. labels = np.array(labels, dtype=np.long)
  100. return imgs, labels
  101. print('Generating test set')
  102. test_imgs, test_labels = sample_batch(256)
  103. def eval_model(model):
  104. num_true = 0
  105. for idx in range(test_imgs.shape[0]):
  106. img = test_imgs[idx]
  107. label = test_labels[idx]
  108. p = model.present_prob(img)
  109. out_label = p > 0.5
  110. #print(out_label)
  111. if np.equal(out_label, label):
  112. num_true += 1
  113. #else:
  114. # if label:
  115. # print("incorrectly predicted as absent")
  116. # else:
  117. # print("incorrectly predicted as present")
  118. acc = 100 * (num_true / test_imgs.shape[0])
  119. return acc
  120. ##############################################################################
  121. batch_size = 128
  122. model = Model()
  123. model.cuda()
  124. optimizer = optim.Adam(
  125. model.parameters(),
  126. lr=5e-4
  127. )
  128. criterion = nn.CrossEntropyLoss()
  129. running_loss = None
  130. for batch_no in range(1, 10000):
  131. batch_imgs, labels = sample_batch(batch_size)
  132. batch_imgs = make_var(batch_imgs)
  133. labels = make_var(labels)
  134. pred = model(batch_imgs)
  135. loss = criterion(pred, labels)
  136. optimizer.zero_grad()
  137. loss.backward()
  138. optimizer.step()
  139. loss = loss.data.detach().item()
  140. running_loss = loss if running_loss is None else 0.99 * running_loss + 0.01 * loss
  141. print('batch #{}, frames={}, loss={:.5f}'.format(
  142. batch_no,
  143. batch_no * batch_size,
  144. running_loss
  145. ))
  146. if batch_no % 25 == 0:
  147. acc = eval_model(model)
  148. print('accuracy: {:.2f}%'.format(acc))