1234567891011121314151617 |
- import torch
- import torch.nn as nn
- # Necessary for my KFAC implementation.
class AddBias(nn.Module):
    """Add a learnable bias along dimension 1 (the channel axis) of the input.

    The bias is stored as a ``(C, 1)`` column-shaped parameter under the
    attribute name ``_bias``. Both the name (state_dict key ``"_bias"``) and
    the shape are preserved because external code may read this parameter
    directly (presumably the KFAC optimizer mentioned next to this class --
    verify against callers before changing either).

    Args:
        bias: Initial bias values; presumably a 1-D tensor of length C.
            It is ``unsqueeze(1)``-ed into the stored ``(C, 1)`` shape.
    """

    def __init__(self, bias: torch.Tensor) -> None:
        super().__init__()
        # Column vector (C, 1); do not change the shape or attribute name.
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the bias broadcast along dimension 1.

        Args:
            x: Tensor laid out as ``(N, C, *rest)`` with at least two
                dimensions, e.g. ``(N, C)``, ``(N, C, L)``,
                ``(N, C, H, W)`` or ``(N, C, D, H, W)``.

        Returns:
            ``x + bias`` with the bias value for channel ``c`` added to
            every element of ``x[:, c, ...]``.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                "AddBias expects an input with at least 2 dims (N, C, ...), "
                f"got shape {tuple(x.shape)}"
            )
        # Reshape (C, 1) -> (1, C, 1, ..., 1): one trailing singleton for each
        # dimension after the channel axis, so the bias broadcasts over the
        # batch and all remaining dims. A fixed (1, C, 1, 1) view would
        # mis-broadcast 3-D and 5-D inputs.
        trailing = (1,) * (x.dim() - 2)
        bias = self._bias.view(1, -1, *trailing)
        return x + bias
|