12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- import torch
- import torch.nn as nn
class AddBias(nn.Module):
    """Add a learnable per-channel bias to the input.

    The bias is broadcast along dimension 1 (the channel dimension), so it
    applies to ``(N, C)``, ``(N, C, L)``, ``(N, C, H, W)`` and higher-rank
    inputs alike.
    """

    def __init__(self, bias):
        """Create the module from initial bias values.

        Args:
            bias: tensor of initial bias values, presumably 1-D of length C
                (one value per channel) -- confirm against callers. It is
                stored as a ``(C, 1)`` parameter named ``_bias``; that name and
                shape are kept unchanged so existing ``state_dict``
                checkpoints keep loading.
        """
        super().__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        """Return ``x + bias`` with the bias aligned to dimension 1 of ``x``.

        Args:
            x: input tensor, expected to be laid out as (batch, channels, ...).

        Returns:
            A new tensor with the same shape as ``x`` (for inputs of rank >= 2).
        """
        # Build a (1, C, 1, ..., 1) view with one trailing singleton per
        # dimension of x beyond the channel axis, so inputs of any rank >= 2
        # (including 3-D sequences) broadcast over the channel axis.
        trailing = (1,) * max(x.dim() - 2, 0)
        bias = self._bias.reshape(1, -1, *trailing)
        return x + bias
def orthogonal(tensor, gain=1):
    """Fill ``tensor`` in place with a (semi-)orthogonal matrix times ``gain``.

    The tensor is viewed as a 2-D matrix of shape
    ``(tensor.size(0), tensor.numel() // tensor.size(0))``. A standard-normal
    matrix of that shape is QR-decomposed, and Q (sign-corrected by the
    diagonal of R) is copied in. When rows < cols the rows come out
    orthonormal; otherwise the columns do.

    Args:
        tensor: tensor with at least 2 dimensions; overwritten in place. It may
            be an ``nn.Parameter`` that requires grad or a non-contiguous view.
        gain: multiplicative scale applied after orthogonalisation.

    Returns:
        The same ``tensor`` object.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")
    if tensor.numel() == 0:
        # Nothing to initialise; tensor[0] below would raise IndexError.
        return tensor

    rows = tensor.size(0)
    cols = tensor[0].numel()
    # Sampled with the default tensor type (same RNG stream as before), so
    # seeded runs reproduce the same initialisation.
    flattened = torch.empty(rows, cols).normal_(0, 1)
    if rows < cols:
        # Use the tall orientation so the reduced QR yields an orthonormal basis.
        flattened = flattened.t()

    try:
        qr = torch.linalg.qr
    except AttributeError:  # PyTorch without torch.linalg.qr (< 1.8)
        qr = torch.qr
    q, r = qr(flattened)

    # QR is unique only up to the signs of R's diagonal; multiplying each
    # column of Q by those signs fixes that ambiguity. (A zero diagonal entry
    # would zero a column, but that has probability zero for Gaussian input.)
    q = q * torch.diag(r, 0).sign()
    if rows < cols:
        q = q.t()

    # In-place writes must bypass autograd, otherwise they fail on leaf
    # parameters that require grad. reshape (not view) tolerates a
    # non-contiguous q or tensor.
    with torch.no_grad():
        tensor.copy_(q.reshape(tensor.shape))
        tensor.mul_(gain)
    return tensor
|