utils.py 414 B

1234567891011121314151617
  1. import torch
  2. import torch.nn as nn
  3. # Necessary for my KFAC implementation.
  4. class AddBias(nn.Module):
  5. def __init__(self, bias):
  6. super(AddBias, self).__init__()
  7. self._bias = nn.Parameter(bias.unsqueeze(1))
  8. def forward(self, x):
  9. if x.dim() == 2:
  10. bias = self._bias.t().view(1, -1)
  11. else:
  12. bias = self._bias.t().view(1, -1, 1, 1)
  13. return x + bias