DL-Art-School/codes/models/vqvae/scaled_weight_conv.py
2021-01-11 20:09:16 -07:00

172 lines
7.1 KiB
Python

import inspect
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.utils import _ntuple
# Normalizes a 2-D conv hyperparameter (kernel_size, stride, padding, dilation, output_padding) to a 2-tuple,
# e.g. 3 -> (3, 3), using torch's internal `_ntuple` helper.
_pair = _ntuple(2)
# Indexes the <p> index of input=b,c,h,w,p by the long tensor index=b,1,h,w. Result is b,c,h,w.
# Frankly - IMO - this is what torch.gather should do.
def index_2d(input, index):
    """Pick one entry along the last dim of ``input`` for every (b, c, h, w) position.

    Args:
        input: tensor of shape (b, c, h, w, p).
        index: integer tensor of shape (b, 1, h, w) with values in [-p, p). Negative values count from the
            end, as with ordinary tensor indexing. Size-1 dims broadcast against input's leading dims.

    Returns:
        Tensor of shape (b, c, h, w) with ``out[b, c, h, w] == input[b, c, h, w, index[b, 0, h, w]]``,
        in the same dtype as ``input``.
    """
    num_choices = input.shape[-1]
    # gather() needs int64 indices; wrap negatives so they match ordinary tensor-indexing semantics.
    index = index.long()
    index = torch.where(index < 0, index + num_choices, index)
    # Broadcast the (b, 1, h, w) index across channels (and any other size-1 dims), then expand input to match.
    idx, _ = torch.broadcast_tensors(index, input[..., 0])
    src = input.expand(*idx.shape, num_choices)
    # gather reads only the selected entries, unlike a one-hot multiply-and-sum, which would turn inf/NaN in
    # unselected slots into NaN and allocate a p-times larger temporary.
    return torch.gather(src, -1, idx.unsqueeze(-1)).squeeze(-1)
# Drop-in implementation of Conv2d that can apply masked scales&shifts to the convolution weights.
class ScaledWeightConv(_ConvNd):
    """Conv2d whose frozen base weight is modulated by ``breadth`` learnable (scale, shift) pairs.

    Each pair gives an effective kernel ``weight * scale + shift``. ``forward`` convolves the input with every
    effective kernel and then keeps, per output position, the variant selected by an integer mask.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            breadth: int = 8,
    ):
        """Arguments mirror ``nn.Conv2d``; ``breadth`` is the number of (scale, shift) pairs."""
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels, out_channels, _pair(kernel_size), stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)
        # Size the scales/shifts from the actual weight, (out_channels, in_channels // groups, kh, kw), so that
        # grouped convolutions and non-square (tuple) kernel sizes broadcast correctly against it.
        weight_shape = self.weight.shape
        self.weight_scales = nn.ParameterList([nn.Parameter(torch.ones(weight_shape)) for _ in range(breadth)])
        self.shifts = nn.ParameterList([nn.Parameter(torch.zeros(weight_shape)) for _ in range(breadth)])
        for w, s in zip(self.weight_scales, self.shifts):
            # Marker attributes; presumably consumed by training code elsewhere in the project (not visible here).
            w.FOR_SCALE_SHIFT = True
            s.FOR_SCALE_SHIFT = True
        # The base weight is frozen. This should probably be configurable at some point.
        self.weight.DO_NOT_TRAIN = True
        self.weight.requires_grad = False

    def _weighted_conv_forward(self, input, weight):
        """Convolve ``input`` with ``weight`` using this module's bias, stride, padding, dilation and groups."""
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad, so the convolution itself uses no padding.
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor, masks: Optional[dict] = None) -> Tensor:
        """Convolve with every (scale, shift) variant and keep one variant per output position.

        Args:
            input: feature map passed to ``F.conv2d``.
            masks: mapping from an output size (``shape[-2]`` of the stacked variants, i.e. the output's last
                spatial dim) to an integer selection tensor consumed by ``index_2d``. When None,
                ``self.masks`` (injected externally) is used.

        Returns:
            The per-position selection made by ``index_2d`` from the stacked variants.
        """
        if masks is None:
            # An alternate "mode" of operation is the masks are injected as parameters.
            assert hasattr(self, 'masks'), 'masks must be passed to forward() or set as self.masks'
            masks = self.masks
        # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any
        # good at all, this can be made more efficient by performing a single conv pass with multiple masks.
        weighted_convs = [self._weighted_conv_forward(input, self.weight * scale + shift)
                          for scale, shift in zip(self.weight_scales, self.shifts)]
        # Stack the variants along a new trailing dim of size `breadth`.
        weighted_convs = torch.stack(weighted_convs, dim=-1)
        needed_mask = weighted_convs.shape[-2]
        assert needed_mask in masks, f'no mask provided for output size {needed_mask}'
        return index_2d(weighted_convs, masks[needed_mask])


def create_wrapped_conv_from_template(conv: nn.Conv2d, breadth: int):
    """Build a ScaledWeightConv with ``conv``'s hyperparameters that reuses ``conv``'s weight and bias.

    The template's weight Parameter is shared, not copied. Freezing it here also sets
    ``requires_grad = False`` on ``conv.weight``. String paddings ('same'/'valid') are not supported.

    Args:
        conv: template convolution.
        breadth: number of (scale, shift) pairs in the wrapper.
    """
    wrapped = ScaledWeightConv(conv.in_channels,
                               conv.out_channels,
                               conv.kernel_size,
                               conv.stride,
                               conv.padding,
                               conv.dilation,
                               conv.groups,
                               # Pass the flag, not the tensor: `if bias:` on a multi-element tensor raises.
                               conv.bias is not None,
                               conv.padding_mode,
                               breadth)
    # Reuse the template's parameters rather than the wrapper's freshly initialised (and frozen) weight.
    wrapped.weight = conv.weight
    wrapped.weight.DO_NOT_TRAIN = True
    wrapped.weight.requires_grad = False
    wrapped.bias = conv.bias
    return wrapped
# Drop-in implementation of ConvTranspose2d that can apply masked scales&shifts to the convolution weights.
class ScaledWeightConvTranspose(_ConvTransposeNd):
    """ConvTranspose2d whose frozen base weight is modulated by ``breadth`` learnable (scale, shift) pairs.

    Each pair gives an effective kernel ``weight * scale + shift``. ``forward`` runs a transposed convolution per
    effective kernel and then keeps, per output position, the variant selected by an integer mask.
    """

    # Newer PyTorch releases insert a `num_spatial_dims` parameter before `dilation` in
    # `_ConvTransposeNd._output_padding`; detect the installed signature once so both variants are called correctly.
    _OUTPUT_PADDING_TAKES_NUM_SPATIAL_DIMS = (
        'num_spatial_dims' in inspect.signature(_ConvTransposeNd._output_padding).parameters)

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            output_padding=0,
            groups: int = 1,
            bias: bool = True,
            dilation: int = 1,
            padding_mode: str = 'zeros',
            breadth: int = 8,
    ):
        """Arguments mirror ``nn.ConvTranspose2d``; ``breadth`` is the number of (scale, shift) pairs."""
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)
        super().__init__(
            in_channels, out_channels, _pair(kernel_size), stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode)
        # Transposed weights are (in_channels, out_channels // groups, kh, kw); sizing from the actual weight keeps
        # grouped convolutions and non-square (tuple) kernel sizes working.
        weight_shape = self.weight.shape
        self.weight_scales = nn.ParameterList([nn.Parameter(torch.ones(weight_shape)) for _ in range(breadth)])
        self.shifts = nn.ParameterList([nn.Parameter(torch.zeros(weight_shape)) for _ in range(breadth)])
        for w, s in zip(self.weight_scales, self.shifts):
            # Marker attributes; presumably consumed by training code elsewhere in the project (not visible here).
            w.FOR_SCALE_SHIFT = True
            s.FOR_SCALE_SHIFT = True
        # The base weight is frozen. This should probably be configurable at some point.
        self.weight.DO_NOT_TRAIN = True
        self.weight.requires_grad = False

    def _conv_transpose_forward(self, input, weight, output_size) -> Tensor:
        """Run ``F.conv_transpose2d`` with ``weight``; ``output_size`` (optional) resolves the output padding."""
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
        if self._OUTPUT_PADDING_TAKES_NUM_SPATIAL_DIMS:
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims=2, dilation=self.dilation)
        else:
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                dilation=self.dilation)
        return F.conv_transpose2d(
            input, weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)

    def forward(self, input: Tensor, masks: Optional[dict] = None,
                output_size: Optional[List[int]] = None) -> Tensor:
        """Run every (scale, shift) variant and keep one variant per output position.

        Args:
            input: feature map passed to ``F.conv_transpose2d``.
            masks: mapping from an output size (``shape[-2]`` of the stacked variants, i.e. the output's last
                spatial dim) to an integer selection tensor consumed by ``index_2d``. When None,
                ``self.masks`` (injected externally) is used.
            output_size: optional target output size, handled as in ``nn.ConvTranspose2d.forward``.

        Returns:
            The per-position selection made by ``index_2d`` from the stacked variants.
        """
        if masks is None:
            # An alternate "mode" of operation is the masks are injected as parameters.
            assert hasattr(self, 'masks'), 'masks must be passed to forward() or set as self.masks'
            masks = self.masks
        # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any
        # good at all, this can be made more efficient by performing a single conv pass with multiple masks.
        weighted_convs = [self._conv_transpose_forward(input, self.weight * scale + shift, output_size)
                          for scale, shift in zip(self.weight_scales, self.shifts)]
        # Stack the variants along a new trailing dim of size `breadth`.
        weighted_convs = torch.stack(weighted_convs, dim=-1)
        needed_mask = weighted_convs.shape[-2]
        assert needed_mask in masks, f'no mask provided for output size {needed_mask}'
        return index_2d(weighted_convs, masks[needed_mask])


def create_wrapped_conv_transpose_from_template(conv: nn.ConvTranspose2d, breadth: int):
    """Build a ScaledWeightConvTranspose with ``conv``'s hyperparameters that reuses ``conv``'s weight and bias.

    The template's weight Parameter is shared, not copied. Freezing it here also sets
    ``requires_grad = False`` on ``conv.weight``.

    Args:
        conv: template transposed convolution.
        breadth: number of (scale, shift) pairs in the wrapper.
    """
    wrapped = ScaledWeightConvTranspose(conv.in_channels,
                                        conv.out_channels,
                                        conv.kernel_size,
                                        conv.stride,
                                        conv.padding,
                                        conv.output_padding,
                                        conv.groups,
                                        # Pass the flag, not the tensor: `if bias:` on a multi-element tensor raises.
                                        conv.bias is not None,
                                        conv.dilation,
                                        conv.padding_mode,
                                        breadth)
    wrapped.weight = conv.weight
    wrapped.weight.DO_NOT_TRAIN = True
    wrapped.weight.requires_grad = False
    wrapped.bias = conv.bias
    return wrapped