train_classifier_onehot.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. #!/usr/bin/env python3
  2. import time
  3. import random
  4. import numpy as np
  5. import gym
  6. from gym_minigrid.register import env_list
  7. from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
  8. from gym_minigrid.wrappers import *
  9. import babyai
  10. import torch
  11. import torch.nn as nn
  12. import torch.optim as optim
  13. import torch.nn.functional as F
  14. from torch.autograd import Variable
  15. import torchvision
  16. import numpy as np
  17. import cv2
  18. import PIL
  19. ##############################################################################
  20. def make_var(arr):
  21. arr = np.ascontiguousarray(arr)
  22. #arr = torch.from_numpy(arr).float()
  23. arr = torch.from_numpy(arr)
  24. arr = Variable(arr)
  25. if torch.cuda.is_available():
  26. arr = arr.cuda()
  27. return arr
  28. def init_weights(m):
  29. classname = m.__class__.__name__
  30. if classname.startswith('Conv'):
  31. nn.init.orthogonal_(m.weight.data)
  32. m.bias.data.fill_(0)
  33. elif classname.find('Linear') != -1:
  34. nn.init.xavier_uniform_(m.weight)
  35. m.bias.data.fill_(0)
  36. elif classname.find('BatchNorm') != -1:
  37. m.weight.data.normal_(1.0, 0.02)
  38. m.bias.data.fill_(0)
  39. class Flatten(nn.Module):
  40. """
  41. Flatten layer, to flatten convolutional layer output
  42. """
  43. def forward(self, input):
  44. return input.view(input.size(0), -1)
  45. class Model(nn.Module):
  46. def __init__(self):
  47. super().__init__()
  48. self.layers = nn.Sequential(
  49. nn.Conv2d(in_channels=20, out_channels=32, kernel_size=1),
  50. nn.LeakyReLU(),
  51. nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1),
  52. nn.LeakyReLU(),
  53. nn.Conv2d(in_channels=32, out_channels=2, kernel_size=1),
  54. nn.LeakyReLU(),
  55. nn.Conv2d(in_channels=2, out_channels=2, kernel_size=7),
  56. nn.LeakyReLU(),
  57. Flatten(),
  58. # Two output heads, one for each class
  59. nn.Linear(2, 2)
  60. )
  61. self.apply(init_weights)
  62. def forward(self, obs):
  63. obs = obs / 16
  64. out = self.layers(obs)
  65. return out
  66. def present_prob(self, obs):
  67. obs = make_var(obs).unsqueeze(0)
  68. logits = self(obs)
  69. probs = F.softmax(logits, dim=-1)
  70. probs = probs.detach().cpu().squeeze().numpy()
  71. return probs[1]
  72. seed = 0
  73. env = gym.make('BabyAI-GoToRedBall-v0')
  74. env = OneHotPartialObsWrapper(env)
  75. def sample_batch(batch_size=128):
  76. global seed
  77. imgs = []
  78. labels = []
  79. for i in range(batch_size):
  80. seed += 1
  81. env.seed(i)
  82. unwrapped_obs = env.unwrapped.reset()['image']
  83. env.seed(i)
  84. obs = env.reset()['image']
  85. ball_visible = ('red', 'ball') in Grid.decode(unwrapped_obs)
  86. obs = obs.transpose([2, 0, 1])
  87. imgs.append(np.copy(obs))
  88. labels.append(ball_visible)
  89. imgs = np.stack(imgs).astype(np.float32)
  90. labels = np.array(labels, dtype=np.long)
  91. return imgs, labels
  92. print('Generating test set')
  93. test_imgs, test_labels = sample_batch(256)
  94. def eval_model(model):
  95. num_true = 0
  96. for idx in range(test_imgs.shape[0]):
  97. img = test_imgs[idx]
  98. label = test_labels[idx]
  99. p = model.present_prob(img)
  100. out_label = p > 0.5
  101. #print(out_label)
  102. if np.equal(out_label, label):
  103. num_true += 1
  104. #else:
  105. # if label:
  106. # print("incorrectly predicted as absent")
  107. # else:
  108. # print("incorrectly predicted as present")
  109. acc = 100 * (num_true / test_imgs.shape[0])
  110. return acc
  111. ##############################################################################
  112. batch_size = 128
  113. model = Model()
  114. model.cuda()
  115. optimizer = optim.Adam(
  116. model.parameters(),
  117. lr=5e-4
  118. )
  119. criterion = nn.CrossEntropyLoss()
  120. running_loss = None
  121. for batch_no in range(1, 10000):
  122. batch_imgs, labels = sample_batch(batch_size)
  123. batch_imgs = make_var(batch_imgs)
  124. labels = make_var(labels)
  125. pred = model(batch_imgs)
  126. loss = criterion(pred, labels)
  127. optimizer.zero_grad()
  128. loss.backward()
  129. optimizer.step()
  130. loss = loss.data.detach().item()
  131. running_loss = loss if running_loss is None else 0.99 * running_loss + 0.01 * loss
  132. print('batch #{}, frames={}, loss={:.5f}'.format(
  133. batch_no,
  134. batch_no * batch_size,
  135. running_loss
  136. ))
  137. if batch_no % 25 == 0:
  138. acc = eval_model(model)
  139. print('accuracy: {:.2f}%'.format(acc))