diff --git a/codes/models/switched_conv.py b/codes/models/switched_conv.py
index 91315e44..20c6925d 100644
--- a/codes/models/switched_conv.py
+++ b/codes/models/switched_conv.py
@@ -4,6 +4,7 @@
 from collections import OrderedDict
 import torch
 import torch.nn as nn
+from lambda_networks import LambdaLayer
 from torch.nn import init, Conv2d
 import torch.nn.functional as F

@@ -21,6 +22,7 @@ class SwitchedConv(nn.Module):
                  bias: bool = True,
                  padding_mode: str = 'zeros',
                  include_coupler: bool = False,  # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
+                 coupler_mode: str = 'standard',
                  coupler_dim_in: int = 0):
         super().__init__()
         self.in_channels = in_channels
@@ -33,7 +35,11 @@ class SwitchedConv(nn.Module):
         self.groups = groups

         if include_coupler:
-            self.coupler = Conv2d(coupler_dim_in, switch_breadth, kernel_size=1)
+            if coupler_mode == 'standard':
+                self.coupler = Conv2d(coupler_dim_in, switch_breadth, kernel_size=1)
+            elif coupler_mode == 'lambda':
+                self.coupler = LambdaLayer(dim=coupler_dim_in, dim_out=switch_breadth, r=23, dim_k=16, heads=2, dim_u=1)
+
         else:
             self.coupler = None

@@ -52,12 +58,15 @@ class SwitchedConv(nn.Module):
             bound = 1 / math.sqrt(fan_in)
             init.uniform_(self.bias, -bound, bound)

-    def forward(self, inp, selector):
+    def forward(self, inp, selector=None):
         if self.coupler:
+            if selector is None:  # A coupler can convert from any input to a selector, so 'None' is allowed.
+                selector = inp
             selector = F.softmax(self.coupler(selector), dim=1)
+        assert selector is not None
         out_shape = [s // self.stride for s in inp.shape[2:]]
         if selector.shape[2] != out_shape[0] or selector.shape[3] != out_shape[1]:
             selector = F.interpolate(selector, size=out_shape, mode="nearest")

         conv_results = []
         for i, w in enumerate(self.weights):
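With include_coupler=True, forward() no longer needs a selector threaded in by the caller: the coupler derives one from the input itself, softmaxes it over the switch breadth, and interpolates it to the output resolution. A minimal call-pattern sketch, assuming the lambda_networks package (pip install lambda-networks) is available; channel counts and shapes here are illustrative:

    import torch
    from models.switched_conv import SwitchedConv

    # 'lambda' coupler: the switch breadth (4) must be divisible by heads (2).
    conv = SwitchedConv(32, 64, 3, 4, padding=1, include_coupler=True,
                        coupler_mode='lambda', coupler_dim_in=32)

    x = torch.randn(2, 32, 48, 48)
    y = conv(x)      # selector omitted; the coupler builds it from x
    y = conv(x, x)   # equivalent: pass the selector source explicitly
    print(y.shape)   # torch.Size([2, 64, 48, 48])

Compared to the 'standard' 1x1-conv coupler, the lambda coupler computes each selector logit from a 23x23 neighborhood (r=23) rather than a single position, so the per-pixel switching decision sees local context.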
diff --git a/codes/models/vqvae/weighted_conv_vqvae.py b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py
similarity index 58%
rename from codes/models/vqvae/weighted_conv_vqvae.py
rename to codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py
index 9452ed26..9037f24f 100644
--- a/codes/models/vqvae/weighted_conv_vqvae.py
+++ b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py
@@ -1,29 +1,25 @@
-# Copyright 2018 The Sonnet Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
 import torch
 from torch import nn
 from torch.nn import functional as F
 import torch.distributed as distributed

-from models.vqvae.scaled_weight_conv import ScaledWeightConv, ScaledWeightConvTranspose
+from models.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv
 from trainer.networks import register_model
 from utils.util import checkpoint, opt_get


+# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
+class UpsampleConv(nn.Module):
+    def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
+        super().__init__()
+        self.conv = SwitchedConv(in_filters, out_filters, kernel_size, breadth, padding=padding, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters)
+
+    def forward(self, x):
+        up = torch.nn.functional.interpolate(x, scale_factor=2)
+        return self.conv(up)
+
+
 class Quantize(nn.Module):
     def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
         super().__init__()
@@ -82,20 +78,15 @@ class ResBlock(nn.Module):
     def __init__(self, in_channel, channel, breadth):
         super().__init__()

-        self.conv = nn.ModuleList([
+        self.conv = nn.Sequential(
             nn.ReLU(inplace=True),
-            ScaledWeightConv(in_channel, channel, 3, padding=1, breadth=breadth),
+            SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
             nn.ReLU(inplace=True),
-            ScaledWeightConv(channel, in_channel, 1, breadth=breadth),
-        ])
+            SwitchedConv(channel, in_channel, 1, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel),
+        )

-    def forward(self, input, masks):
-        out = input
-        for m in self.conv:
-            if isinstance(m, ScaledWeightConv):
-                out = m(out, masks)
-            else:
-                out = m(out)
+    def forward(self, input):
+        out = self.conv(input)
         out += input

         return out
@@ -107,34 +98,29 @@ class Encoder(nn.Module):

         if stride == 4:
             blocks = [
-                ScaledWeightConv(in_channel, channel // 2, 4, stride=2, padding=1, breadth=breadth),
+                SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
                 nn.ReLU(inplace=True),
-                ScaledWeightConv(channel // 2, channel, 4, stride=2, padding=1, breadth=breadth),
+                SwitchedConv(channel // 2, channel, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2),
                 nn.ReLU(inplace=True),
-                ScaledWeightConv(channel, channel, 3, padding=1, breadth=breadth),
+                SwitchedConv(channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel),
             ]

         elif stride == 2:
             blocks = [
-                ScaledWeightConv(in_channel, channel // 2, 4, stride=2, padding=1, breadth=breadth),
+                SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel),
                 nn.ReLU(inplace=True),
-                ScaledWeightConv(channel // 2, channel, 3, padding=1, breadth=breadth),
+                SwitchedConv(channel // 2, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2),
             ]

         for i in range(n_res_block):
-            blocks.append(ResBlock(channel, n_res_channel, breadth=breadth))
+            blocks.append(ResBlock(channel, n_res_channel, breadth))

         blocks.append(nn.ReLU(inplace=True))

-        self.blocks = nn.ModuleList(blocks)
+        self.blocks = nn.Sequential(*blocks)

     def forward(self, input):
-        for block in self.blocks:
-            if isinstance(block, ScaledWeightConv) or isinstance(block, ResBlock):
-                input = block(input, self.masks)
-            else:
-                input = block(input)
-        return input
+        return self.blocks(input)


 class Decoder(nn.Module):
@@ -143,39 +129,33 @@
     ):
         super().__init__()

-        blocks = [ScaledWeightConv(in_channel, channel, 3, padding=1, breadth=breadth)]
+        blocks = [SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel)]

         for i in range(n_res_block):
-            blocks.append(ResBlock(channel, n_res_channel, breadth=breadth))
+            blocks.append(ResBlock(channel, n_res_channel, breadth))

         blocks.append(nn.ReLU(inplace=True))

         if stride == 4:
             blocks.extend(
                 [
-                    ScaledWeightConvTranspose(channel, channel // 2, 4, stride=2, padding=1, breadth=breadth),
+                    UpsampleConv(channel, channel // 2, breadth, 5, padding=2),
                     nn.ReLU(inplace=True),
-                    ScaledWeightConvTranspose(
-                        channel // 2, out_channel, 4, stride=2, padding=1, breadth=breadth
+                    UpsampleConv(
+                        channel // 2, out_channel, breadth, 5, padding=2
                     ),
                 ]
             )

         elif stride == 2:
             blocks.append(
-                ScaledWeightConvTranspose(channel, out_channel, 4, stride=2, padding=1, breadth=breadth)
+                UpsampleConv(channel, out_channel, breadth, 5, padding=2)
             )

-        self.blocks = nn.ModuleList(blocks)
+        self.blocks = nn.Sequential(*blocks)

     def forward(self, input):
-        for block in self.blocks:
-            if isinstance(block, ScaledWeightConvTranspose) or isinstance(block, ResBlock) \
-                    or isinstance(block, ScaledWeightConv):
-                input = block(input, self.masks)
-            else:
-                input = block(input)
-        return input
+        return self.blocks(input)


 class VQVAE(nn.Module):
@@ -187,22 +167,22 @@
         n_res_channel=32,
         codebook_dim=64,
         codebook_size=512,
-        breadth=8,
         decay=0.99,
+        breadth=4,
     ):
         super().__init__()

         self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth)
         self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth)
-        self.quantize_conv_t = ScaledWeightConv(channel, codebook_dim, 1, breadth=breadth)
+        self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
         self.quantize_t = Quantize(codebook_dim, codebook_size)
         self.dec_t = Decoder(
             codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth
         )
-        self.quantize_conv_b = ScaledWeightConv(codebook_dim + channel, codebook_dim, 1, breadth=breadth)
-        self.quantize_b = Quantize(codebook_dim, codebook_size)
-        self.upsample_t = ScaledWeightConvTranspose(
-            codebook_dim, codebook_dim, 4, stride=2, padding=1, breadth=breadth
+        self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
+        self.quantize_b = Quantize(codebook_dim, codebook_size*2)
+        self.upsample_t = UpsampleConv(
+            codebook_dim, codebook_dim, breadth, 5, padding=2
         )
         self.dec = Decoder(
             codebook_dim + codebook_dim,
@@ -214,21 +194,17 @@
             breadth=breadth
         )

-    def forward(self, input, masks):
-        # This awkward injection point is necessary to enable checkpointing to work.
-        for m in [self.enc_b, self.enc_t, self.dec_t, self.dec]:
-            m.masks = masks
-
-        quant_t, quant_b, diff, _, _ = self.encode(input, masks)
-        dec = self.decode(quant_t, quant_b, masks)
+    def forward(self, input):
+        quant_t, quant_b, diff, _, _ = self.encode(input)
+        dec = self.decode(quant_t, quant_b)

         return dec, diff

-    def encode(self, input, masks):
+    def encode(self, input):
         enc_b = checkpoint(self.enc_b, input)
         enc_t = checkpoint(self.enc_t, enc_b)

-        quant_t = self.quantize_conv_t(enc_t, masks).permute(0, 2, 3, 1)
+        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
         quant_t, diff_t, id_t = self.quantize_t(quant_t)
         quant_t = quant_t.permute(0, 3, 1, 2)
         diff_t = diff_t.unsqueeze(0)
@@ -236,15 +212,15 @@
         dec_t = checkpoint(self.dec_t, quant_t)
         enc_b = torch.cat([dec_t, enc_b], 1)

-        quant_b = self.quantize_conv_b(enc_b, masks).permute(0, 2, 3, 1)
+        quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1)
         quant_b, diff_b, id_b = self.quantize_b(quant_b)
         quant_b = quant_b.permute(0, 3, 1, 2)
         diff_b = diff_b.unsqueeze(0)

         return quant_t, quant_b, diff_t + diff_b, id_t, id_b

-    def decode(self, quant_t, quant_b, masks):
-        upsample_t = self.upsample_t(quant_t, masks)
+    def decode(self, quant_t, quant_b):
+        upsample_t = self.upsample_t(quant_t)
         quant = torch.cat([upsample_t, quant_b], 1)

         dec = checkpoint(self.dec, quant)
@@ -256,12 +232,27 @@
         quant_b = self.quantize_b.embed_code(code_b)
         quant_b = quant_b.permute(0, 3, 1, 2)

-        dec = self.decode(quant_t, quant_b, masks)
+        dec = self.decode(quant_t, quant_b)

         return dec


+def convert_weights(weights_file):
+    sd = torch.load(weights_file)
+    import models.vqvae.vqvae_no_conv_transpose as stdvq
+    std_model = stdvq.VQVAE()
+    std_model.load_state_dict(sd)
+    nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 4, ['quantize_conv_t', 'quantize_conv_b'])
+    torch.save(nsd, "converted.pth")
+
+
 @register_model
-def register_weighted_vqvae(opt_net, opt):
+def register_vqvae_norm_switched_conv_lambda(opt_net, opt):
     kw = opt_get(opt_net, ['kwargs'], {})
     return VQVAE(**kw)
+
+
+if __name__ == '__main__':
+    #v = VQVAE()
+    #print(v(torch.randn(1,3,128,128))[0].shape)
+    convert_weights("../../../experiments/4000_generator.pth")
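convert_weights() bootstraps this network from a checkpoint of the plain-conv models.vqvae.vqvae_no_conv_transpose model: the old state dict is loaded into that architecture, convert_conv_net_state_dict_to_switched_conv expands its conv weights across the switch breadth (the 4 here needs to match this model's breadth default), skipping the two quantize_conv_* layers since they remain plain nn.Conv2d, and the result is saved to converted.pth. A quick untrained smoke test expanding the commented-out lines in __main__, assuming the repo's checkpoint() wrapper passes through cleanly outside the trainer:

    import torch
    from models.vqvae.vqvae_no_conv_transpose_switched_lambda import VQVAE

    v = VQVAE()  # defaults: breadth=4, codebook_dim=64, codebook_size=512
    recon, diff = v(torch.randn(1, 3, 128, 128))
    print(recon.shape)  # expected: torch.Size([1, 3, 128, 128])

Note that quantize_b is now built with codebook_size*2 entries, double the standard model's bottom codebook, which is worth keeping in mind when reusing converted checkpoints.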