kfac.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import AddBias

# TODO: In order to make this code faster:
# 1) Implement _extract_patches as a single cuda kernel
# 2) Compute QR decomposition in a separate process
# 3) Actually make a general KFAC optimizer so it fits PyTorch


def _extract_patches(x, kernel_size, stride, padding):
    """Unfold a (N, C, H, W) input into flattened convolution patches
    of shape (N, out_h, out_w, C * kernel_h * kernel_w)."""
    if padding[0] + padding[1] > 0:
        x = F.pad(x, (padding[1], padding[1], padding[0],
                      padding[0])).data  # Actually check dims
    x = x.unfold(2, kernel_size[0], stride[0])
    x = x.unfold(3, kernel_size[1], stride[1])
    x = x.transpose_(1, 2).transpose_(2, 3).contiguous()
    x = x.view(
        x.size(0), x.size(1), x.size(2),
        x.size(3) * x.size(4) * x.size(5))
    return x


def compute_cov_a(a, classname, layer_info, fast_cnn):
    """Covariance of the layer inputs (the A factor of the Kronecker approximation)."""
    batch_size = a.size(0)

    if classname == 'Conv2d':
        if fast_cnn:
            a = _extract_patches(a, *layer_info)
            a = a.view(a.size(0), -1, a.size(-1))
            a = a.mean(1)
        else:
            a = _extract_patches(a, *layer_info)
            # Flatten spatial locations and rescale by the spatial dimensions.
            a = a.view(-1, a.size(-1)).div_(a.size(1)).div_(a.size(2))
    elif classname == 'AddBias':
        # The "input" to a bias term is a constant 1 for every sample.
        is_cuda = a.is_cuda
        a = torch.ones(a.size(0), 1)
        if is_cuda:
            a = a.cuda()

    return a.t() @ (a / batch_size)


def compute_cov_g(g, classname, layer_info, fast_cnn):
    """Covariance of the gradients w.r.t. the layer outputs (the G factor)."""
    batch_size = g.size(0)

    if classname == 'Conv2d':
        if fast_cnn:
            g = g.view(g.size(0), g.size(1), -1)
            g = g.sum(-1)
        else:
            g = g.transpose(1, 2).transpose(2, 3).contiguous()
            # Flatten spatial locations and rescale by the spatial dimensions.
            g = g.view(-1, g.size(-1)).mul_(g.size(1)).mul_(g.size(2))
    elif classname == 'AddBias':
        g = g.view(g.size(0), g.size(1), -1)
        g = g.sum(-1)

    g_ = g * batch_size
    return g_.t() @ (g_ / g.size(0))


def update_running_stat(aa, m_aa, momentum):
    # Do the trick to keep aa unchanged and not create any additional tensors
    m_aa *= momentum / (1 - momentum)
    m_aa += aa
    m_aa *= (1 - momentum)


class SplitBias(nn.Module):
    def __init__(self, module):
        super(SplitBias, self).__init__()
        self.module = module
        self.add_bias = AddBias(module.bias.data)
        self.module.bias = None

    def forward(self, input):
        x = self.module(input)
        x = self.add_bias(x)
        return x


class KFACOptimizer(optim.Optimizer):
    def __init__(self,
                 model,
                 lr=0.25,
                 momentum=0.9,
                 stat_decay=0.99,
                 kl_clip=0.001,
                 damping=1e-2,
                 weight_decay=0,
                 fast_cnn=False,
                 Ts=1,
                 Tf=10):
        defaults = dict()

        def split_bias(module):
            # Wrap every child that owns a bias in SplitBias so the bias is
            # handled as a separate AddBias layer. Children created without a
            # bias (bias=None, e.g. bias=False) are left untouched.
            for mname, child in module.named_children():
                if hasattr(child, 'bias') and child.bias is not None:
                    module._modules[mname] = SplitBias(child)
                else:
                    split_bias(child)

        split_bias(model)

        super(KFACOptimizer, self).__init__(model.parameters(), defaults)

        self.known_modules = {'Linear', 'Conv2d', 'AddBias'}

        self.modules = []
        self.grad_outputs = {}

        self.model = model
        self._prepare_model()

        self.steps = 0
        # Toggled externally around the "Fisher" backward pass; initialize it
        # so the backward hook does not fail before the first toggle.
        self.acc_stats = False

        self.m_aa, self.m_gg = {}, {}
        self.Q_a, self.Q_g = {}, {}
        self.d_a, self.d_g = {}, {}

        self.momentum = momentum
        self.stat_decay = stat_decay

        self.lr = lr
        self.kl_clip = kl_clip
        self.damping = damping
        self.weight_decay = weight_decay

        self.fast_cnn = fast_cnn

        self.Ts = Ts
        self.Tf = Tf

        self.optim = optim.SGD(
            model.parameters(),
            lr=self.lr * (1 - self.momentum),
            momentum=self.momentum)

    def _save_input(self, module, input):
        # `Variable.volatile` from the original code was removed in PyTorch 0.4;
        # only skip statistics collection when gradients are globally disabled.
        if torch.is_grad_enabled() and self.steps % self.Ts == 0:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            aa = compute_cov_a(input[0].data, classname, layer_info,
                               self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = aa.clone()

            update_running_stat(aa, self.m_aa[module], self.stat_decay)

    def _save_grad_output(self, module, grad_input, grad_output):
        if self.acc_stats:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            gg = compute_cov_g(grad_output[0].data, classname,
                               layer_info, self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_gg[module] = gg.clone()

            update_running_stat(gg, self.m_gg[module], self.stat_decay)

    def _prepare_model(self):
        for module in self.model.modules():
            classname = module.__class__.__name__
            if classname in self.known_modules:
                assert not ((classname in ['Linear', 'Conv2d']) and module.bias is not None), \
                    "You must have a bias as a separate layer"

                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
                # register_backward_hook is deprecated in newer PyTorch in favor
                # of register_full_backward_hook; kept here to match the
                # original behavior.
                module.register_backward_hook(self._save_grad_output)

    def step(self):
        # Add weight decay
        if self.weight_decay > 0:
            for p in self.model.parameters():
                p.grad.data.add_(p.data, alpha=self.weight_decay)

        updates = {}
        for i, m in enumerate(self.modules):
            assert len(list(m.parameters())) == 1, \
                "Can handle only one parameter at the moment"
            classname = m.__class__.__name__
            p = next(m.parameters())

            la = self.damping + self.weight_decay

            if self.steps % self.Tf == 0:
                # My asynchronous implementation exists, I will add it later.
                # Experimenting with different ways to do this in PyTorch.
                # torch.symeig was removed from recent PyTorch releases;
                # torch.linalg.eigh is its documented replacement.
                self.d_a[m], self.Q_a[m] = torch.linalg.eigh(self.m_aa[m])
                self.d_g[m], self.Q_g[m] = torch.linalg.eigh(self.m_gg[m])

                # Zero out numerically negligible eigenvalues.
                self.d_a[m].mul_((self.d_a[m] > 1e-6).float())
                self.d_g[m].mul_((self.d_g[m] > 1e-6).float())

            if classname == 'Conv2d':
                p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1)
            else:
                p_grad_mat = p.grad.data

            # Apply the damped inverse of the Kronecker-factored curvature in
            # the eigenbasis of the A and G factors.
            v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
            v2 = v1 / (
                self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la)
            v = self.Q_g[m] @ v2 @ self.Q_a[m].t()

            v = v.view(p.grad.data.size())
            updates[p] = v

        # Rescale the update so the estimated KL change stays within kl_clip.
        vg_sum = 0
        for p in self.model.parameters():
            v = updates[p]
            vg_sum += (v * p.grad.data * self.lr * self.lr).sum()

        nu = min(1, math.sqrt(self.kl_clip / vg_sum))

        for p in self.model.parameters():
            v = updates[p]
            p.grad.data.copy_(v)
            p.grad.data.mul_(nu)

        self.optim.step()
        self.steps += 1
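

# A minimal usage sketch (an illustrative assumption, not part of the original
# training code): drive KFACOptimizer the way ACKTR-style loops typically do,
# accumulating curvature statistics on a backward pass through a sampled
# "Fisher" loss, then taking the preconditioned step on the actual loss.
# The model, data, and losses below are placeholders.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 2))
    optimizer = KFACOptimizer(model)

    inputs = torch.randn(32, 4)
    targets = torch.randint(0, 2, (32,))

    # Forward pass; the forward pre-hooks collect activation statistics here.
    outputs = model(inputs)

    if optimizer.steps % optimizer.Ts == 0:
        # Sample labels from the model's own predictive distribution so the
        # accumulated statistics approximate the Fisher rather than the
        # empirical Fisher.
        optimizer.zero_grad()
        sampled = torch.multinomial(
            F.softmax(outputs.detach(), dim=1), 1).squeeze(1)
        fisher_loss = F.cross_entropy(outputs, sampled)
        optimizer.acc_stats = True
        fisher_loss.backward(retain_graph=True)
        optimizer.acc_stats = False

    # Actual loss and preconditioned update.
    optimizer.zero_grad()
    loss = F.cross_entropy(outputs, targets)
    loss.backward()
    optimizer.step()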