From 96bc80313c4519407d7ec920a9282caeb40c6a70 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 26 Jan 2021 09:31:53 -0700
Subject: [PATCH] Add switch norm, up dropout rate, detach selector

---
 codes/models/switched_conv_hard_routing.py    | 116 ++++++-
 ...e_no_conv_transpose_hardswitched_lambda.py | 293 ++++++++++++++++++
 ...vqvae_no_conv_transpose_switched_lambda.py |  32 +-
 3 files changed, 415 insertions(+), 26 deletions(-)
 create mode 100644 codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py

diff --git a/codes/models/switched_conv_hard_routing.py b/codes/models/switched_conv_hard_routing.py
index 0fafc654..a9320063 100644
--- a/codes/models/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv_hard_routing.py
@@ -7,6 +7,7 @@ from lambda_networks import LambdaLayer
 from torch.nn import init, Conv2d, MSELoss
 import torch.nn.functional as F
 from tqdm import tqdm
+import torch.distributed as dist
 
 
 class SwitchedConvHardRoutingFunction(torch.autograd.Function):
@@ -37,11 +38,90 @@ class SwitchedConvHardRoutingFunction(torch.autograd.Function):
         return grad, grad_sel, grad_w, grad_b, None
 
 
+"""
+SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of
+switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
+of switch weights that are over-used and increasing the magnitude of under-used weights.
+
+The return value has the exact same format as a normal Softmax output and can be used directly as the input of a
+switch equation.
+
+Since the whole point of a convolutional switch is to enable training extra-wide networks to operate on a large number
+of image categories, it makes almost no sense to perform this type of norm against a single mini-batch of images: some
+of the switches will not be used in such a small context - and that's good! This is solved by accumulation. Every
+forward pass computes a norm across the current minibatch. That norm is added into a rotating buffer of size
+accumulator_size. The actual normalization occurs across the entire rotating buffer.
+
+You should set the accumulator size according to three factors:
+- Your batch size. A smaller batch size should mean a greater accumulator size.
+- Your image diversity. More diverse images have less need for the accumulator.
+- How wide your switch/switching group size is. More groups mean you're going to want more accumulation.
+
+Note: This norm makes the (potentially flawed) assumption that each forward() pass sees unique data. For maximum
+      effectiveness, avoid feeding the same data repeatedly - or make alterations to work around it.
+Note: This norm does nothing for the first accumulator_size iterations.
+""" +class SwitchNorm(nn.Module): + def __init__(self, group_size, accumulator_size=128): + super().__init__() + self.accumulator_desired_size = accumulator_size + self.group_size = group_size + self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu')) + self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu')) + self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size)) + + def add_norm_to_buffer(self, x): + flat = x.sum(dim=[0, 2, 3]) + norm = flat / torch.mean(flat) + + self.accumulator[self.accumulator_index] = norm.detach().clone() + self.accumulator_index += 1 + if self.accumulator_index >= self.accumulator_desired_size: + self.accumulator_index *= 0 + if self.accumulator_filled <= 0: + self.accumulator_filled += 1 + + # Input into forward is a switching tensor of shape (batch,groups,width,height) + def forward(self, x: torch.Tensor, update_attention_norm=True): + assert len(x.shape) == 4 + + # Push the accumulator to the right device on the first iteration. + if self.accumulator.device != x.device: + self.accumulator = self.accumulator.to(x.device) + + # In eval, don't change the norm buffer. + if self.training and update_attention_norm: + self.add_norm_to_buffer(x) + + # Reduce across all distributed entities, if needed + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(self.accumulator, op=dist.ReduceOp.SUM) + self.accumulator /= dist.get_world_size() + + # Compute the norm factor. + if self.accumulator_filled > 0: + norm = torch.mean(self.accumulator, dim=0) + else: + norm = torch.ones(self.group_size, device=self.accumulator.device) + x = x / norm.view(1,-1,1,1) + + # Need to re-normalize x so that the groups dimension sum to 1, just like when it was fed in. + return x / x.sum(dim=1, keepdim=True) + + class SwitchedConvHardRouting(nn.Module): - def __init__(self, in_c, out_c, kernel_sz, breadth, stride=1, bias=True, dropout_rate=0.0, - include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate. - coupler_mode: str = 'standard', - coupler_dim_in: int = 0,): + def __init__(self, + in_c, + out_c, + kernel_sz, + breadth, + stride=1, + bias=True, + dropout_rate=0.0, + include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate. 
+ coupler_mode: str = 'standard', + coupler_dim_in: int = 0, + switch_norm: bool = True): super().__init__() self.in_channels = in_c self.out_channels = out_c @@ -50,12 +130,22 @@ class SwitchedConvHardRouting(nn.Module): self.has_bias = bias self.breadth = breadth self.dropout_rate = dropout_rate + if switch_norm: + self.switch_norm = SwitchNorm(breadth, accumulator_size=512) + else: + self.switch_norm = None if include_coupler: if coupler_mode == 'standard': self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1) elif coupler_mode == 'lambda': - self.coupler = LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1) + self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1), + nn.BatchNorm2d(coupler_dim_in), + nn.ReLU(), + LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1), + nn.BatchNorm2d(breadth), + nn.ReLU(), + Conv2d(breadth, breadth, 1)) else: self.coupler = None @@ -85,11 +175,14 @@ class SwitchedConvHardRouting(nn.Module): # If a coupler was specified, run that to convert selector into a softmax distribution. if self.coupler: if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed. - selector = input + selector = input.detach() selector = F.softmax(self.coupler(selector), dim=1) - self.last_select = selector.detach().clone() assert selector is not None + # Perform normalization on the selector if applicable. + if self.switch_norm: + selector = self.switch_norm(selector) + # Apply dropout at the batch level per kernel. if self.training and self.dropout_rate > 0: b, c, h, w = selector.shape @@ -99,6 +192,10 @@ class SwitchedConvHardRouting(nn.Module): drop = drop.logical_or(fix_blank) selector = drop * selector + # Debugging variables + self.last_select = selector.detach().clone() + self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1) + return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride) @@ -107,6 +204,8 @@ class SwitchedConvHardRouting(nn.Module): def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_list=[]): state_dict = module.state_dict() for name, m in module.named_modules(): + if not isinstance(m, nn.Conv2d): + continue ignored = False for smod in ignore_list: if smod in name: @@ -114,8 +213,7 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_ continue if ignored: continue - if isinstance(m, nn.Conv2d): - state_dict[f'{name}.weight'] = state_dict[f'{name}.weight'].unsqueeze(2).repeat(1,1,switch_breadth,1,1) + state_dict[f'{name}.weight'] = state_dict[f'{name}.weight'].unsqueeze(2).repeat(1,1,switch_breadth,1,1) return state_dict diff --git a/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py b/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py new file mode 100644 index 00000000..06ab7045 --- /dev/null +++ b/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py @@ -0,0 +1,293 @@ +import os + +import torch +import torchvision +from torch import nn +from torch.nn import functional as F + +import torch.distributed as distributed + +from models.switched_conv_hard_routing import SwitchedConvHardRouting, \ + convert_conv_net_state_dict_to_switched_conv +from trainer.networks import register_model +from utils.util import checkpoint, opt_get + + +# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper. 
+class UpsampleConv(nn.Module): + def __init__(self, in_filters, out_filters, breadth, kernel_size, padding): + super().__init__() + self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters, dropout_rate=0.4) + + def forward(self, x): + up = torch.nn.functional.interpolate(x, scale_factor=2) + return self.conv(up) + + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, input): + flatten = input.reshape(-1, self.dim) + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True) + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + if distributed.is_initialized() and distributed.get_world_size() > 1: + distributed.all_reduce(embed_onehot_sum) + distributed.all_reduce(embed_sum) + + self.cluster_size.data.mul_(self.decay).add_( + embed_onehot_sum, alpha=1 - self.decay + ) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = ( + (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + diff = (quantize.detach() - input).pow(2).mean() + quantize = input + (quantize - input).detach() + + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, channel, breadth): + super().__init__() + + self.conv = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv2d(in_channel, channel, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(channel, in_channel, 1), + ) + + def forward(self, input): + out = self.conv(input) + out += input + + return out + + +class Encoder(nn.Module): + def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, breadth): + super().__init__() + + if stride == 4: + blocks = [ + nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), + nn.ReLU(inplace=True), + SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4), + nn.ReLU(inplace=True), + SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel, dropout_rate=0.4), + ] + + elif stride == 2: + blocks = [ + nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), + nn.ReLU(inplace=True), + SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4), + ] + + for i in range(n_res_block): + blocks.append(ResBlock(channel, n_res_channel, breadth)) + + blocks.append(nn.ReLU(inplace=True)) + + self.blocks = nn.Sequential(*blocks) + + def forward(self, input): + return 
self.blocks(input) + + +class Decoder(nn.Module): + def __init__( + self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, breadth + ): + super().__init__() + + blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.4)] + + for i in range(n_res_block): + blocks.append(ResBlock(channel, n_res_channel, breadth)) + + blocks.append(nn.ReLU(inplace=True)) + + if stride == 4: + blocks.extend( + [ + UpsampleConv(channel, channel // 2, breadth, 5, padding=2), + nn.ReLU(inplace=True), + UpsampleConv( + channel // 2, out_channel, breadth, 5, padding=2 + ), + ] + ) + + elif stride == 2: + blocks.append( + UpsampleConv(channel, out_channel, breadth, 5, padding=2) + ) + + self.blocks = nn.Sequential(*blocks) + + def forward(self, input): + return self.blocks(input) + + +class VQVAE(nn.Module): + def __init__( + self, + in_channel=3, + channel=128, + n_res_block=2, + n_res_channel=32, + codebook_dim=64, + codebook_size=512, + decay=0.99, + breadth=8, + ): + super().__init__() + + self.breadth = breadth + self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth) + self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth) + self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) + self.quantize_t = Quantize(codebook_dim, codebook_size) + self.dec_t = Decoder( + codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth + ) + self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) + self.quantize_b = Quantize(codebook_dim, codebook_size*2) + self.upsample_t = UpsampleConv( + codebook_dim, codebook_dim, breadth, 5, padding=2 + ) + self.dec = Decoder( + codebook_dim + codebook_dim, + in_channel, + channel, + n_res_block, + n_res_channel, + stride=4, + breadth=breadth + ) + + def forward(self, input): + quant_t, quant_b, diff, _, _ = self.encode(input) + dec = self.decode(quant_t, quant_b) + + return dec, diff + + def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'): + from matplotlib import cm + magnitude, indices = torch.topk(attention_out, 3, dim=1) + indices = indices.cpu() + colormap = cm.get_cmap(cmap_discrete_name, attention_size) + img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's + img = img.permute((0, 3, 1, 2)) + torchvision.utils.save_image(img, output_file) + + def visual_dbg(self, step, path): + convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]] + for i, c in enumerate(convs): + self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth) + + def get_debug_values(self, step, __): + switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]), + ('enc_b_blk4', self.enc_b.blocks[4]), + ('enc_t_blk2', self.enc_t.blocks[2]), + ('dec_t_blk0', self.dec_t.blocks[0]), + ('dec_t_blk-1', self.dec_t.blocks[-1].conv), + ('dec_blk0', self.dec.blocks[0]), + ('dec_blk-1', self.dec.blocks[-1].conv), + ('dec_blk-3', self.dec.blocks[-3].conv)] + logs = {} + for name, swc in switched_convs: + logs[f'{name}_histogram_switch_usage'] = swc.latest_masks + return logs + + def encode(self, input): + enc_b = checkpoint(self.enc_b, input) + enc_t = checkpoint(self.enc_t, enc_b) + + quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) + quant_t, diff_t, id_t = 
self.quantize_t(quant_t) + quant_t = quant_t.permute(0, 3, 1, 2) + diff_t = diff_t.unsqueeze(0) + + dec_t = checkpoint(self.dec_t, quant_t) + enc_b = torch.cat([dec_t, enc_b], 1) + + quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) + quant_b, diff_b, id_b = self.quantize_b(quant_b) + quant_b = quant_b.permute(0, 3, 1, 2) + diff_b = diff_b.unsqueeze(0) + + return quant_t, quant_b, diff_t + diff_b, id_t, id_b + + def decode(self, quant_t, quant_b): + upsample_t = self.upsample_t(quant_t) + quant = torch.cat([upsample_t, quant_b], 1) + dec = checkpoint(self.dec, quant) + + return dec + + def decode_code(self, code_t, code_b): + quant_t = self.quantize_t.embed_code(code_t) + quant_t = quant_t.permute(0, 3, 1, 2) + quant_b = self.quantize_b.embed_code(code_b) + quant_b = quant_b.permute(0, 3, 1, 2) + + dec = self.decode(quant_t, quant_b) + + return dec + + +def convert_weights(weights_file): + sd = torch.load(weights_file) + import models.vqvae.vqvae_no_conv_transpose as stdvq + std_model = stdvq.VQVAE() + std_model.load_state_dict(sd) + nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 8, ['quantize_conv_t', 'quantize_conv_b', + 'enc_b.blocks.0', 'enc_t.blocks.0', + 'conv.1', 'conv.3']) + torch.save(nsd, "converted.pth") + + +@register_model +def register_vqvae_norm_hard_switched_conv_lambda(opt_net, opt): + kw = opt_get(opt_net, ['kwargs'], {}) + return VQVAE(**kw) + + +if __name__ == '__main__': + v = VQVAE(breadth=8).cuda() + print(v(torch.randn(1,3,128,128).cuda())[0].shape) + #convert_weights("../../../experiments/50000_generator.pth") diff --git a/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py index c01d9716..84c5170a 100644 --- a/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py +++ b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py @@ -7,8 +7,7 @@ from torch.nn import functional as F import torch.distributed as distributed -from models.switched_conv_hard_routing import SwitchedConvHardRouting, \ - convert_conv_net_state_dict_to_switched_conv +from models.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv from trainer.networks import register_model from utils.util import checkpoint, opt_get @@ -17,7 +16,7 @@ from utils.util import checkpoint, opt_get class UpsampleConv(nn.Module): def __init__(self, in_filters, out_filters, breadth, kernel_size, padding): super().__init__() - self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters, dropout_rate=0.2) + self.conv = SwitchedConv(in_filters, out_filters, kernel_size, breadth, padding=padding, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters) def forward(self, x): up = torch.nn.functional.interpolate(x, scale_factor=2) @@ -84,9 +83,9 @@ class ResBlock(nn.Module): self.conv = nn.Sequential( nn.ReLU(inplace=True), - nn.Conv2d(in_channel, channel, 3, padding=1), + SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel), nn.ReLU(inplace=True), - nn.Conv2d(channel, in_channel, 1), + SwitchedConv(channel, in_channel, 1, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel), ) def forward(self, input): @@ -102,18 +101,18 @@ class Encoder(nn.Module): if stride == 4: blocks = [ - SwitchedConvHardRouting(in_channel, channel // 2, 5, breadth, stride=2, include_coupler=True, 
coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.2), + SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel), nn.ReLU(inplace=True), - SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.2), + SwitchedConv(channel // 2, channel, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2), nn.ReLU(inplace=True), - SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel, dropout_rate=0.2), + SwitchedConv(channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel), ] elif stride == 2: blocks = [ - SwitchedConvHardRouting(in_channel, channel // 2, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.2), + SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel), nn.ReLU(inplace=True), - SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.2), + SwitchedConv(channel // 2, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2), ] for i in range(n_res_block): @@ -133,7 +132,7 @@ class Decoder(nn.Module): ): super().__init__() - blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.2)] + blocks = [SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel)] for i in range(n_res_block): blocks.append(ResBlock(channel, n_res_channel, breadth)) @@ -172,7 +171,7 @@ class VQVAE(nn.Module): codebook_dim=64, codebook_size=512, decay=0.99, - breadth=8, + breadth=4, ): super().__init__() @@ -261,8 +260,7 @@ def convert_weights(weights_file): import models.vqvae.vqvae_no_conv_transpose as stdvq std_model = stdvq.VQVAE() std_model.load_state_dict(sd) - nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 1, ['quantize_conv_t', 'quantize_conv_b', - 'conv.1', 'conv.3']) + nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 4, ['quantize_conv_t', 'quantize_conv_b']) torch.save(nsd, "converted.pth") @@ -273,6 +271,6 @@ def register_vqvae_norm_switched_conv_lambda(opt_net, opt): if __name__ == '__main__': - v = VQVAE(breadth=8).cuda() - print(v(torch.randn(1,3,128,128).cuda())[0].shape) - #convert_weights("../../../experiments/50000_generator.pth") + #v = VQVAE() + #print(v(torch.randn(1,3,128,128))[0].shape) + convert_weights("../../../experiments/4000_generator.pth")
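
The following is not part of the patch: a minimal usage sketch of the SwitchNorm module added above, included for reference. It assumes the repository's codes/ directory is on PYTHONPATH (so models.switched_conv_hard_routing resolves) and only exercises the behavior described in the docstring: the output keeps the Softmax contract of summing to 1 over the group dimension, and the norm is effectively a pass-through until the rotating accumulator has filled once. The group size, accumulator size, and tensor shapes below are arbitrary illustration values.

import torch
import torch.nn.functional as F

from models.switched_conv_hard_routing import SwitchNorm

norm = SwitchNorm(group_size=8, accumulator_size=4)
norm.train()

for step in range(6):
    # Fake selector of shape (batch, groups, h, w), softmax-normalized over the group
    # dimension and deliberately biased toward group 0 so the norm has something to correct.
    logits = torch.randn(2, 8, 16, 16)
    logits[:, 0] += 2.0
    selector = F.softmax(logits, dim=1)

    out = norm(selector)

    # The group dimension still sums to 1, as required by the switching equation.
    assert torch.allclose(out.sum(dim=1), torch.ones(2, 16, 16), atol=1e-5)

    # The norm only begins rebalancing once the rotating buffer has wrapped.
    print(step, bool(norm.accumulator_filled > 0), out.mean(dim=[0, 2, 3]))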
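
Likewise for reference (also not part of the patch), a small sketch of what convert_conv_net_state_dict_to_switched_conv does now that the isinstance guard runs first: only nn.Conv2d weights outside the ignore list are expanded, by duplicating each (out_c, in_c, k, k) kernel along a new breadth dimension so every switch path starts from the same pretrained filter. The toy nn.Sequential below is purely illustrative.

import torch
from torch import nn

from models.switched_conv_hard_routing import convert_conv_net_state_dict_to_switched_conv

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 3, 1))
new_sd = convert_conv_net_state_dict_to_switched_conv(net, switch_breadth=8, ignore_list=['2'])

print(net.state_dict()['0.weight'].shape)  # torch.Size([16, 3, 3, 3])    - original kernel
print(new_sd['0.weight'].shape)            # torch.Size([16, 3, 8, 3, 3]) - repeated across breadth
print(new_sd['2.weight'].shape)            # torch.Size([3, 16, 1, 1])    - in ignore_list, left unchanged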