131 lines
5.5 KiB
Python
131 lines
5.5 KiB
Python
|
from typing import Optional, List
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from torch import Tensor
|
||
|
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
|
||
|
from torch.nn.modules.utils import _ntuple
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
# Helper that expands a scalar conv argument into a 2-tuple for the two spatial
# dims, e.g. kernel_size=3 -> (3, 3); presumably an existing 2-sequence is passed
# through as a tuple (torch's _ntuple behaviour - confirm for the installed version).
_pair = _ntuple(2)
|
||
|
|
||
|
|
||
|
# Indexes the <p> index of input=b,c,h,w,p by the long tensor index=b,1,h,w. Result is b,c,h,w.
# torch.gather does exactly this once the index is expanded over channels and given
# a trailing singleton axis for the gathered dimension.
def index_2d(input, index):
    """Select one of the ``p`` candidates at every position of ``input``.

    Args:
        input: Tensor of shape ``(b, c, h, w, p)``.
        index: Integer tensor of shape ``(b, 1, h, w)`` with values in ``[0, p)``.
            The same choice is applied to every channel. Size-1 dims are
            broadcast against ``input``.

    Returns:
        Tensor of shape ``(b, c, h, w)`` with ``input``'s dtype, where
        ``result[b, c, y, x] == input[b, c, y, x, index[b, 0, y, x]]``.
    """
    # Unlike a one-hot multiply, gather does not materialise a (b, c, h, w, p)
    # mask, does not promote half-precision inputs to float32, and does not let
    # NaN/inf in unselected entries leak into the result (0 * inf == nan).
    gather_index = index.long().expand(*input.shape[:-1]).unsqueeze(-1)
    return torch.gather(input, -1, gather_index).squeeze(-1)
|
||
|
|
||
|
|
||
|
# Drop-in implementation of Conv2d that can apply masked scales & shifts to the convolution weights.
class ScaledWeightConv(_ConvNd):
    """2-D convolution whose weight is modulated by ``breadth`` learned (scale, shift) pairs.

    For each pair ``i`` the convolution is evaluated with the weight
    ``self.weight * weight_scales[i] + shifts[i]``. The ``breadth`` outputs are stacked
    on a new trailing dimension, and ``index_2d`` picks one of them per position using
    a mask passed to ``forward``.

    Every parameter that is not a scale/shift (``weight`` and, when enabled, ``bias``)
    is tagged ``DO_NOT_TRAIN = True``. Scales and shifts are tagged
    ``FOR_SCALE_SHIFT = True``.
    NOTE(review): these are plain attributes; presumably training code elsewhere reads
    them when building the optimizer - confirm.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        breadth: int = 8,
    ):
        """Build the frozen base convolution plus ``breadth`` scale/shift pairs.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
            bias, padding_mode: as for ``torch.nn.Conv2d``; scalars are expanded to pairs.
            breadth: number of (scale, shift) pairs, i.e. candidate outputs per position.
        """
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, _pair(kernel_size), stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        # Size the scales/shifts from the weight _ConvNd actually created,
        # (out_channels, in_channels // groups, kH, kW). Building the shape from
        # the raw ``kernel_size`` argument failed for tuple kernel sizes, and
        # using the full ``in_channels`` could not broadcast when groups > 1.
        weight_shape = self.weight.shape
        self.weight_scales = nn.ParameterList(
            [nn.Parameter(torch.ones(weight_shape)) for _ in range(breadth)])
        self.shifts = nn.ParameterList(
            [nn.Parameter(torch.zeros(weight_shape)) for _ in range(breadth)])
        for scale, shift in zip(self.weight_scales, self.shifts):
            scale.FOR_SCALE_SHIFT = True
            shift.FOR_SCALE_SHIFT = True
        # Freeze everything that is not a scale/shift (weight and bias).
        # This should probably be configurable at some point.
        for p in self.parameters():
            if not hasattr(p, "FOR_SCALE_SHIFT"):
                p.DO_NOT_TRAIN = True

    def _weighted_conv_forward(self, input, weight):
        """Run ``F.conv2d`` with ``weight`` and this module's bias and geometry.

        For non-``'zeros'`` padding modes the input is padded explicitly with
        ``F.pad`` and then convolved with zero implicit padding.
        """
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor, masks: dict) -> Tensor:
        """Convolve once per (scale, shift) pair and select one result per position.

        Args:
            input: feature map passed to ``F.conv2d``. The selection below assumes a
                batched 4-D input - TODO confirm callers never pass unbatched input.
            masks: dict keyed by output width (``shape[-2]`` of the stacked outputs).
                Each value is an index tensor for ``index_2d``, presumably a long
                ``(b, 1, h, w)`` tensor with values in ``[0, breadth)``.

        Returns:
            The per-position selected convolution output.

        Raises:
            KeyError: if ``masks`` has no entry for the output width.
        """
        # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any
        # good at all, this can be made more efficient by performing a single conv pass with multiple masks.
        candidates = [self._weighted_conv_forward(input, self.weight * scale + shift)
                      for scale, shift in zip(self.weight_scales, self.shifts)]
        weighted_convs = torch.stack(candidates, dim=-1)

        # The candidate axis is last, so shape[-2] is the output width.
        needed_mask = weighted_convs.shape[-2]
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if needed_mask not in masks:
            raise KeyError(f"no mask provided for output width {needed_mask}")

        return index_2d(weighted_convs, masks[needed_mask])
|
||
|
|
||
|
|
||
|
# Drop-in implementation of ConvTranspose2d that can apply masked scales & shifts to the convolution weights.
class ScaledWeightConvTranspose(_ConvTransposeNd):
    """2-D transposed convolution whose weight is modulated by ``breadth`` (scale, shift) pairs.

    For each pair ``i`` the transposed convolution is evaluated with the weight
    ``self.weight * weight_scales[i] + shifts[i]``. The outputs are stacked on a new
    trailing dimension, and ``index_2d`` picks one per position using a mask passed to
    ``forward``.

    Only the parameter named ``weight`` is tagged ``DO_NOT_TRAIN = True``. Scales and
    shifts are tagged ``FOR_SCALE_SHIFT = True``.
    NOTE(review): ``bias`` is left untagged (trainable) here, while ScaledWeightConv
    tags its bias as well - confirm which is intended.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups: int = 1,
        bias: bool = True,
        dilation: int = 1,
        padding_mode: str = 'zeros',
        breadth: int = 8,
    ):
        """Build the base transposed convolution plus ``breadth`` scale/shift pairs.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode: as for ``torch.nn.ConvTranspose2d``;
            scalars are expanded to pairs.
            breadth: number of (scale, shift) pairs, i.e. candidate outputs per position.
        """
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)
        super().__init__(
            in_channels, out_channels, _pair(kernel_size), stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode)

        # Size the scales/shifts from the weight actually created for a transposed
        # conv, (in_channels, out_channels // groups, kH, kW). The raw
        # ``kernel_size`` argument broke tuple kernel sizes, and the full
        # ``out_channels`` could not broadcast when groups > 1.
        weight_shape = self.weight.shape
        self.weight_scales = nn.ParameterList(
            [nn.Parameter(torch.ones(weight_shape)) for _ in range(breadth)])
        self.shifts = nn.ParameterList(
            [nn.Parameter(torch.zeros(weight_shape)) for _ in range(breadth)])
        for scale, shift in zip(self.weight_scales, self.shifts):
            scale.FOR_SCALE_SHIFT = True
            shift.FOR_SCALE_SHIFT = True
        # This should probably be configurable at some point.
        for name, p in self.named_parameters():
            if name == 'weight':
                p.DO_NOT_TRAIN = True

    def _resolve_output_padding(self, input, output_size):
        """Call the private ``_output_padding`` helper across PyTorch versions.

        Newer PyTorch inserts a ``num_spatial_dims`` parameter before ``dilation``.
        Passing ``self.dilation`` positionally in that slot breaks as soon as
        ``output_size`` is given, so choose the call form from the signature.
        """
        import inspect
        params = inspect.signature(self._output_padding).parameters
        if 'num_spatial_dims' in params:
            return self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation)
        return self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)

    def _conv_transpose_forward(self, input, weight, output_size) -> Tensor:
        """Run ``F.conv_transpose2d`` with ``weight`` and this module's bias and geometry.

        Args:
            input: feature map to upsample.
            weight: weight to use instead of ``self.weight``.
            output_size: optional target size used to derive the output padding, as in
                ``ConvTranspose2d.forward``; ``None`` uses ``self.output_padding``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        output_padding = self._resolve_output_padding(input, output_size)

        return F.conv_transpose2d(
            input, weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)

    def forward(self, input: Tensor, masks: dict, output_size: Optional[List[int]] = None) -> Tensor:
        """Transposed-convolve once per (scale, shift) pair and select one result per position.

        Args:
            input: feature map. The selection below assumes a batched 4-D input -
                TODO confirm callers never pass unbatched input.
            masks: dict keyed by output width (``shape[-2]`` of the stacked outputs).
                Each value is an index tensor for ``index_2d``, presumably a long
                ``(b, 1, h, w)`` tensor with values in ``[0, breadth)``.
            output_size: optional target output size, forwarded to the output-padding
                computation.

        Returns:
            The per-position selected transposed-convolution output.

        Raises:
            KeyError: if ``masks`` has no entry for the output width.
        """
        # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any
        # good at all, this can be made more efficient by performing a single conv pass with multiple masks.
        candidates = [self._conv_transpose_forward(input, self.weight * scale + shift, output_size)
                      for scale, shift in zip(self.weight_scales, self.shifts)]
        weighted_convs = torch.stack(candidates, dim=-1)

        # The candidate axis is last, so shape[-2] is the output width.
        needed_mask = weighted_convs.shape[-2]
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if needed_mask not in masks:
            raise KeyError(f"no mask provided for output width {needed_mask}")

        return index_2d(weighted_convs, masks[needed_mask])
|