# utils.py
  1. import torch
  2. import torch.nn as nn
  3. # Necessary for my KFAC implementation.
  4. class AddBias(nn.Module):
  5. def __init__(self, bias):
  6. super(AddBias, self).__init__()
  7. self._bias = nn.Parameter(bias.unsqueeze(1))
  8. def forward(self, x):
  9. if x.dim() == 2:
  10. bias = self._bias.t().view(1, -1)
  11. else:
  12. bias = self._bias.t().view(1, -1, 1, 1)
  13. return x + bias
  14. # A temporary solution from the master branch.
  15. # https://github.com/pytorch/pytorch/blob/7752fe5d4e50052b3b0bbc9109e599f8157febc0/torch/nn/init.py#L312
  16. # Remove after the next version of PyTorch gets release.
  17. def orthogonal(tensor, gain=1):
  18. if tensor.ndimension() < 2:
  19. raise ValueError("Only tensors with 2 or more dimensions are supported")
  20. rows = tensor.size(0)
  21. cols = tensor[0].numel()
  22. flattened = torch.Tensor(rows, cols).normal_(0, 1)
  23. if rows < cols:
  24. flattened.t_()
  25. # Compute the qr factorization
  26. q, r = torch.qr(flattened)
  27. # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
  28. d = torch.diag(r, 0)
  29. ph = d.sign()
  30. q *= ph.expand_as(q)
  31. if rows < cols:
  32. q.t_()
  33. tensor.view_as(q).copy_(q)
  34. tensor.mul_(gain)
  35. return tensor