From f2a31702b5d7aab06158507c57bfa8d1e86f2c1c Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 20 Oct 2021 21:19:25 -0600 Subject: [PATCH] Clean stuff up, move more things into arch_util --- codes/models/arch_util.py | 44 +- codes/models/diffusion/diffusion_dvae.py | 387 ------------------ codes/models/vqvae/vqvae_3.py | 179 -------- codes/models/vqvae/vqvae_3_hardswitch.py | 279 ------------- .../models/vqvae/vqvae_3_separated_coupler.py | 179 -------- codes/models/vqvae/vqvae_audio_xformer.py | 149 ------- codes/models/vqvae/vqvae_no_conv_transpose.py | 265 ------------ ...e_no_conv_transpose_hardswitched_lambda.py | 293 ------------- ...vqvae_no_conv_transpose_switched_lambda.py | 276 ------------- .../diffusion/diffusion_noise_surfer.py | 2 +- codes/train.py | 2 +- 11 files changed, 37 insertions(+), 2018 deletions(-) delete mode 100644 codes/models/diffusion/diffusion_dvae.py delete mode 100644 codes/models/vqvae/vqvae_3.py delete mode 100644 codes/models/vqvae/vqvae_3_hardswitch.py delete mode 100644 codes/models/vqvae/vqvae_3_separated_coupler.py delete mode 100644 codes/models/vqvae/vqvae_audio_xformer.py delete mode 100644 codes/models/vqvae/vqvae_no_conv_transpose.py delete mode 100644 codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py delete mode 100644 codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index 2b171fe8..227b5eff 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -8,6 +8,38 @@ import torch.nn.functional as F import torch.nn.utils.spectral_norm as SpectralNorm from math import sqrt + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def l2norm(t): + return F.normalize(t, p = 2, dim = -1) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) + + +def laplace_smoothing(x, n_categories, eps = 1e-5): + return (x + eps) / (x.sum() + n_categories * eps) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device = device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device = device) + + return samples[indices] + + def kaiming_init(module, a=0, mode='fan_out', @@ -24,9 +56,11 @@ def kaiming_init(module, if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) + def pixel_norm(x, epsilon=1e-8): return x * torch.rsqrt(torch.mean(torch.pow(x, 2), dim=1, keepdims=True) + epsilon) + def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] @@ -75,20 +109,12 @@ def default_init_weights(module, scale=1): elif isinstance(m, nn.Linear): kaiming_init(m, a=0, mode='fan_in', bias=0) m.weight.data *= scale -""" -Various utilities for neural networks. -""" - -import math - -import torch as th -import torch.nn as nn # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
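# Illustrative sketch (editorial, not part of the patch): several of the helpers moved into
# arch_util.py above (ema_inplace, laplace_smoothing, sample_vectors) are the usual building
# blocks of an EMA vector-quantization codebook update. The sketch below shows how they
# typically compose; it assumes the helpers are importable from models.arch_util as in the
# hunk above, while every other name (init_codebook_from_batch, ema_codebook_update,
# codebook, cluster_size, embed_avg) is a hypothetical stand-in, not an identifier from this
# repository.
import torch
import torch.nn.functional as F

from models.arch_util import ema_inplace, laplace_smoothing, sample_vectors


def init_codebook_from_batch(samples, codebook_size):
    # Seed the codebook from the first batch of encoder outputs; sample_vectors falls back
    # to sampling with replacement when the batch is smaller than the codebook.
    return sample_vectors(samples, codebook_size)


def ema_codebook_update(flat_inputs, codebook, cluster_size, embed_avg, decay=0.99):
    # flat_inputs: (N, dim); codebook, embed_avg: (codebook_size, dim); cluster_size: (codebook_size,)
    dist = torch.cdist(flat_inputs, codebook)                # distance to every code
    embed_ind = dist.argmin(dim=1)                           # nearest-code assignment
    onehot = F.one_hot(embed_ind, codebook.shape[0]).type(flat_inputs.dtype)

    ema_inplace(cluster_size, onehot.sum(0), decay)          # running usage count per code
    ema_inplace(embed_avg, onehot.t() @ flat_inputs, decay)  # running sum of assigned vectors

    # Laplace smoothing keeps rarely used codes from producing a near-zero denominator.
    smoothed = laplace_smoothing(cluster_size, codebook.shape[0]) * cluster_size.sum()
    codebook.data.copy_(embed_avg / smoothed.unsqueeze(1))
    return embed_ind


# Usage: one update step on random features.
if __name__ == '__main__':
    feats = torch.randn(1024, 64)
    codebook = init_codebook_from_batch(feats, 512)
    cluster_size = torch.zeros(512)
    embed_avg = codebook.clone()
    ema_codebook_update(feats, codebook, cluster_size, embed_avg)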
class SiLU(nn.Module): def forward(self, x): - return x * th.sigmoid(x) + return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py deleted file mode 100644 index 305f7879..00000000 --- a/codes/models/diffusion/diffusion_dvae.py +++ /dev/null @@ -1,387 +0,0 @@ -from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16 -from models.diffusion.gaussian_diffusion import get_named_beta_schedule -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.respace import SpacedDiffusion, space_timesteps -from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \ - Downsample, Upsample -import torch -import torch.nn as nn - -from models.gpt_voice.lucidrains_dvae import eval_decorator -from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner -from models.vqvae.gumbel_quantizer import GumbelQuantizer -from models.vqvae.vqvae import Quantize -from trainer.networks import register_model -from utils.util import get_mask_from_lengths -import models.gpt_voice.mini_encoder as menc - - -class DiscreteEncoder(nn.Module): - def __init__(self, - in_channels, - model_channels, - out_channels, - dropout, - scale): - super().__init__() - self.blocks = nn.Sequential( - conv_nd(1, in_channels, model_channels, 3, padding=1), - menc.ResBlock(model_channels, dropout, dims=1), - Downsample(model_channels, use_conv=True, dims=1, out_channels=model_channels*2, factor=scale), - menc.ResBlock(model_channels*2, dropout, dims=1), - Downsample(model_channels*2, use_conv=True, dims=1, out_channels=model_channels*4, factor=scale), - menc.ResBlock(model_channels*4, dropout, dims=1), - AttentionBlock(model_channels*4, num_heads=4), - menc.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1), - ) - - def forward(self, spectrogram): - return self.blocks(spectrogram) - - -class DiscreteDecoder(nn.Module): - def __init__(self, in_channels, level_channels, scale): - super().__init__() - # Just raw upsampling, return a dict with each layer. - self.init = conv_nd(1, in_channels, level_channels[0], kernel_size=3, padding=1) - layers = [] - for i, lvl in enumerate(level_channels[:-1]): - layers.append(nn.Sequential(normalization(lvl), - nn.SiLU(lvl), - Upsample(lvl, use_conv=True, dims=1, out_channels=level_channels[i+1], factor=scale))) - self.layers = nn.ModuleList(layers) - - def forward(self, x): - y = self.init(x) - outs = [y] - for layer in self.layers: - y = layer(y) - outs.append(y) - return outs - - -class DiffusionDVAE(nn.Module): - def __init__( - self, - model_channels, - num_res_blocks, - in_channels=1, - out_channels=2, # mean and variance - spectrogram_channels=80, - spectrogram_conditioning_levels=[3,4,5], # Levels at which spectrogram conditioning is applied to the waveform. 
- dropout=0, - channel_mult=(1, 2, 4, 8, 16, 32, 64), - attention_resolutions=(16,32,64), - conv_resample=True, - dims=1, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - use_new_attention_order=False, - kernel_size=5, - quantize_dim=1024, - num_discrete_codes=8192, - scale_steps=4, - conditioning_inputs_provided=True, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.spectrogram_channels = spectrogram_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.dtype = torch.float16 if use_fp16 else torch.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.dims = dims - self.spectrogram_conditioning_levels = spectrogram_conditioning_levels - self.scale_steps = scale_steps - - self.encoder = DiscreteEncoder(spectrogram_channels, model_channels*4, quantize_dim, dropout, scale_steps) - #self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True) - self.quantizer = GumbelQuantizer(quantize_dim, quantize_dim, num_discrete_codes) - # For recording codebook usage. - self.codes = torch.zeros((1228800,), dtype=torch.long) - self.code_ind = 0 - self.internal_step = 0 - decoder_channels = [model_channels * channel_mult[s-1] for s in spectrogram_conditioning_levels] - self.decoder = DiscreteDecoder(quantize_dim, decoder_channels[::-1], scale_steps) - - padding = 1 if kernel_size == 3 else 2 - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.conditioning_enabled = conditioning_inputs_provided - if conditioning_inputs_provided: - self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim) - self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim) - self.embedding_combiner = EmbeddingCombiner(time_embed_dim) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - self.convergence_convs = nn.ModuleList([]) - for level, mult in enumerate(channel_mult): - if level in spectrogram_conditioning_levels: - self.convergence_convs.append(conv_nd(dims, ch*2, ch, 1)) - - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_steps - ) - ) - ) - ch = out_ch - 
input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - kernel_size=kernel_size, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads_upsample, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append( - Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_steps) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), - ) - - def get_debug_values(self, step, __): - # Note: this is very poor design, but quantizer.get_temperature not only retrieves the temperature, it also updates the step and thus it is extremely important that this function get called regularly. - return {'histogram_codes': self.codes, 'quantizer_temperature': self.quantizer.get_temperature(step)} - - @torch.no_grad() - @eval_decorator - def get_codebook_indices(self, images): - img = self.norm(images) - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) - sampled, commitment_loss, codes = self.codebook(logits) - return codes - - def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals): - if self.conditioning_enabled: - assert conditioning_inputs is not None - - spec_hs = self.decoder(embeddings)[::-1] - # Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.) - spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)] - convergence_fns = list(self.convergence_convs) - - # Timestep embeddings and conditioning signals are combined using a small transformer. - hs = [] - emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - if self.conditioning_enabled: - mask = get_mask_from_lengths(num_conditioning_signals+1, conditioning_inputs.shape[1]+1) # +1 to account for the timestep embeddings we'll add. - emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1) - emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1) - emb = self.embedding_combiner(emb, mask, self.query_gen(spec_hs[0])) - else: - emb = emb1 - - # The rest is the diffusion vocoder, built as a standard U-net. 
spec_h is gradually fed into the encoder. - next_spec = spec_hs.pop(0) - next_convergence_fn = convergence_fns.pop(0) - h = x.type(self.dtype) - for k, module in enumerate(self.input_blocks): - h = module(h, emb) - if next_spec is not None and h.shape[-1] == next_spec.shape[-1]: - h = torch.cat([h, next_spec], dim=1) - h = next_convergence_fn(h) - if len(spec_hs) > 0: - next_spec = spec_hs.pop(0) - next_convergence_fn = convergence_fns.pop(0) - else: - next_spec = None - hs.append(h) - assert len(spec_hs) == 0 - assert len(convergence_fns) == 0 - h = self.middle_block(h, emb) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb) - h = h.type(x.dtype) - return self.out(h) - - def decode(self, x, timesteps, codes, conditioning_inputs=None, num_conditioning_signals=None): - assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement. - embeddings = self.quantizer.embed_code(codes).permute((0,2,1)) - return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals) - - def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None): - # Compute DVAE portion first. - spec_logits = self.encoder(spectrogram).permute((0,2,1)) - sampled, commitment_loss, codes = self.quantizer(spec_logits) - - if self.training: - # Compute from softmax outputs to preserve gradients. - embeddings = sampled.permute((0,2,1)) - else: - # Compute from codes only. - embeddings = self.quantizer.embed_code(codes).permute((0,2,1)) - - # This is so we can debug the distribution of codes being learned. - if self.internal_step % 50 == 0: - codes = codes.flatten() - l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l - self.codes[i:i+l] = codes.cpu() - self.code_ind = self.code_ind + l - if self.code_ind >= self.codes.shape[0]: - self.code_ind = 0 - self.internal_step += 1 - - return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals), commitment_loss - - -@register_model -def register_unet_diffusion_dvae(opt_net, opt): - return DiffusionDVAE(**opt_net['kwargs']) - - - -''' - - -class DiffusionDVAE(nn.Module): - def __init__( - self, - model_channels, - num_res_blocks, - in_channels=1, - out_channels=2, # mean and variance - spectrogram_channels=80, - spectrogram_conditioning_levels=[3,4,5], # Levels at which spectrogram conditioning is applied to the waveform. 
- dropout=0, - channel_mult=(1, 2, 4, 8, 16, 32, 64), - attention_resolutions=(16,32,64), - conv_resample=True, - dims=1, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - use_new_attention_order=False, - kernel_size=5, - quantize_dim=1024, - num_discrete_codes=8192, - scale_steps=4, - conditioning_inputs_provided=True, - ): - ''' - -# Test for ~4 second audio clip at 22050Hz -if __name__ == '__main__': - spec = torch.randn(4, 80, 160) - ts = torch.LongTensor([432, 234, 100, 555]) - model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2], - channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False) - print(model(torch.randn_like(spec), ts, spec)[0].shape) diff --git a/codes/models/vqvae/vqvae_3.py b/codes/models/vqvae/vqvae_3.py deleted file mode 100644 index 9d58110f..00000000 --- a/codes/models/vqvae/vqvae_3.py +++ /dev/null @@ -1,179 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - -import torch.distributed as distributed - -from models.vqvae.vqvae import ResBlock, Quantize -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper. -class UpsampleConv(nn.Module): - def __init__(self, in_filters, out_filters, kernel_size, padding): - super().__init__() - self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding) - - def forward(self, x): - up = torch.nn.functional.interpolate(x, scale_factor=2) - return self.conv(up) - - -class Encoder(nn.Module): - def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride): - super().__init__() - - if stride == 4: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.LeakyReLU(inplace=True), - nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2), - nn.LeakyReLU(inplace=True), - nn.Conv2d(channel, channel, 3, padding=1), - ] - - elif stride == 2: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.LeakyReLU(inplace=True), - nn.Conv2d(channel // 2, channel, 3, padding=1), - ] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) - - blocks.append(nn.LeakyReLU(inplace=True)) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class Decoder(nn.Module): - def __init__( - self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride - ): - super().__init__() - - blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) - - blocks.append(nn.LeakyReLU(inplace=True)) - - if stride == 4: - blocks.extend( - [ - UpsampleConv(channel, channel // 2, 5, padding=2), - nn.LeakyReLU(inplace=True), - UpsampleConv( - channel // 2, out_channel, 5, padding=2 - ), - ] - ) - - elif stride == 2: - blocks.append( - UpsampleConv(channel, out_channel, 5, padding=2) - ) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class VQVAE3(nn.Module): - def __init__( - self, - in_channel=3, - channel=128, - n_res_block=2, - n_res_channel=32, - codebook_dim=64, - codebook_size=512, - decay=0.99, - ): - super().__init__() - - self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1), - 
nn.LeakyReLU(inplace=True)]) - self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4) - self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2) - self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) - self.quantize_t = Quantize(codebook_dim, codebook_size) - self.dec_t = Decoder( - codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2 - ) - self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) - self.quantize_b = Quantize(codebook_dim, codebook_size) - self.upsample_t = UpsampleConv( - codebook_dim, codebook_dim, 5, padding=2 - ) - self.dec = Decoder( - codebook_dim + codebook_dim, - 32, - channel, - n_res_block, - n_res_channel, - stride=4, - ) - self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1) - - def forward(self, input): - quant_t, quant_b, diff, _, _ = self.encode(input) - dec = self.decode(quant_t, quant_b) - - return dec, diff - - def encode(self, input): - fea = self.initial_conv(input) - enc_b = checkpoint(self.enc_b, fea) - enc_t = checkpoint(self.enc_t, enc_b) - - quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) - quant_t, diff_t, id_t = self.quantize_t(quant_t) - quant_t = quant_t.permute(0, 3, 1, 2) - diff_t = diff_t.unsqueeze(0) - - dec_t = checkpoint(self.dec_t, quant_t) - enc_b = torch.cat([dec_t, enc_b], 1) - - quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) - quant_b, diff_b, id_b = self.quantize_b(quant_b) - quant_b = quant_b.permute(0, 3, 1, 2) - diff_b = diff_b.unsqueeze(0) - - return quant_t, quant_b, diff_t + diff_b, id_t, id_b - - def decode(self, quant_t, quant_b): - upsample_t = self.upsample_t(quant_t) - quant = torch.cat([upsample_t, quant_b], 1) - dec = checkpoint(self.dec, quant) - dec = checkpoint(self.final_conv, dec) - - return dec - - def decode_code(self, code_t, code_b): - quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute(0, 3, 1, 2) - quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute(0, 3, 1, 2) - - dec = self.decode(quant_t, quant_b) - - return dec - - -@register_model -def register_vqvae3(opt_net, opt): - kw = opt_get(opt_net, ['kwargs'], {}) - return VQVAE3(**kw) - - -if __name__ == '__main__': - v = VQVAE3() - print(v(torch.randn(1,3,128,128))[0].shape) diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py deleted file mode 100644 index fe118407..00000000 --- a/codes/models/vqvae/vqvae_3_hardswitch.py +++ /dev/null @@ -1,279 +0,0 @@ -import os -from time import time - -import torch -import torchvision -import torch.distributed as distributed -from torch import nn -from tqdm import tqdm - -from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \ - convert_conv_net_state_dict_to_switched_conv -from models.vqvae.vqvae import ResBlock, Quantize -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper. 
-class UpsampleConv(nn.Module): - def __init__(self, in_filters, out_filters, kernel_size, padding, cfg): - super().__init__() - self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth=cfg['breadth'], include_coupler=True, - coupler_mode=cfg['mode'], coupler_dim_in=in_filters, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']) - - def forward(self, x): - up = torch.nn.functional.interpolate(x, scale_factor=2) - return self.conv(up) - - -class Encoder(nn.Module): - def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, cfg): - super().__init__() - - if stride == 4: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.LeakyReLU(inplace=True), - SwitchedConvHardRouting(channel // 2, channel, 5, breadth=cfg['breadth'], stride=2, include_coupler=True, - coupler_mode=cfg['mode'], coupler_dim_in=channel // 2, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']), - nn.LeakyReLU(inplace=True), - SwitchedConvHardRouting(channel, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'], - coupler_dim_in=channel, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']), - ] - - elif stride == 2: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.LeakyReLU(inplace=True), - SwitchedConvHardRouting(channel // 2, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'], - coupler_dim_in=channel // 2, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled']), - ] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) - - blocks.append(nn.LeakyReLU(inplace=True)) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class Decoder(nn.Module): - def __init__( - self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, cfg - ): - super().__init__() - - blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth=cfg['breadth'], include_coupler=True, coupler_mode=cfg['mode'], - coupler_dim_in=in_channel, dropout_rate=cfg['dropout'], hard_en=cfg['hard_enabled'])] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) - - blocks.append(nn.LeakyReLU(inplace=True)) - - if stride == 4: - blocks.extend( - [ - UpsampleConv(channel, channel // 2, 5, padding=2, cfg=cfg), - nn.LeakyReLU(inplace=True), - UpsampleConv( - channel // 2, out_channel, 5, padding=2, cfg=cfg - ), - ] - ) - - elif stride == 2: - blocks.append( - UpsampleConv(channel, out_channel, 5, padding=2, cfg=cfg) - ) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class VQVAE3HardSwitch(nn.Module): - def __init__( - self, - in_channel=3, - channel=128, - n_res_block=2, - n_res_channel=32, - codebook_dim=64, - codebook_size=512, - decay=0.99, - cfg={'mode':'standard', 'breadth':16, 'hard_enabled': True, 'dropout': 0.4} - ): - super().__init__() - - self.cfg = cfg - self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1), - nn.LeakyReLU(inplace=True)]) - self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4, cfg=cfg) - self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, cfg=cfg) - self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) - self.quantize_t = Quantize(codebook_dim, codebook_size) - self.dec_t = Decoder( - codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, cfg=cfg - ) - self.quantize_conv_b = 
nn.Conv2d(codebook_dim + channel, codebook_dim, 1) - self.quantize_b = Quantize(codebook_dim, codebook_size) - self.upsample_t = UpsampleConv( - codebook_dim, codebook_dim, 5, padding=2, cfg=cfg - ) - self.dec = Decoder( - codebook_dim + codebook_dim, - 32, - channel, - n_res_block, - n_res_channel, - stride=4, - cfg=cfg - ) - self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1) - - def forward(self, input): - quant_t, quant_b, diff, _, _ = self.encode(input) - dec = self.decode(quant_t, quant_b) - - return dec, diff - - def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'): - from matplotlib import cm - magnitude, indices = torch.topk(attention_out, 3, dim=1) - indices = indices.cpu() - colormap = cm.get_cmap(cmap_discrete_name, attention_size) - img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's - img = img.permute((0, 3, 1, 2)) - torchvision.utils.save_image(img, output_file) - - def visual_dbg(self, step, path): - convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]] - for i, c in enumerate(convs): - self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, 16) - - def get_debug_values(self, step, __): - switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]), - ('enc_b_blk4', self.enc_b.blocks[4]), - ('enc_t_blk2', self.enc_t.blocks[2]), - ('dec_t_blk0', self.dec_t.blocks[0]), - ('dec_t_blk-1', self.dec_t.blocks[-1].conv), - ('dec_blk0', self.dec.blocks[0]), - ('dec_blk-1', self.dec.blocks[-1].conv), - ('dec_blk-3', self.dec.blocks[-3].conv)] - logs = {} - for name, swc in switched_convs: - logs[f'{name}_histogram_switch_usage'] = swc.latest_masks - return logs - - def encode(self, input): - fea = self.initial_conv(input) - enc_b = checkpoint(self.enc_b, fea) - enc_t = checkpoint(self.enc_t, enc_b) - - quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) - quant_t, diff_t, id_t = self.quantize_t(quant_t) - quant_t = quant_t.permute(0, 3, 1, 2) - diff_t = diff_t.unsqueeze(0) - - dec_t = checkpoint(self.dec_t, quant_t) - enc_b = torch.cat([dec_t, enc_b], 1) - - quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) - quant_b, diff_b, id_b = self.quantize_b(quant_b) - quant_b = quant_b.permute(0, 3, 1, 2) - diff_b = diff_b.unsqueeze(0) - - return quant_t, quant_b, diff_t + diff_b, id_t, id_b - - def decode(self, quant_t, quant_b): - upsample_t = self.upsample_t(quant_t) - quant = torch.cat([upsample_t, quant_b], 1) - dec = checkpoint(self.dec, quant) - dec = checkpoint(self.final_conv, dec) - - return dec - - def decode_code(self, code_t, code_b): - quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute(0, 3, 1, 2) - quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute(0, 3, 1, 2) - - dec = self.decode(quant_t, quant_b) - - return dec - - - -def convert_weights(weights_file): - sd = torch.load(weights_file) - from models.vqvae.vqvae_3 import VQVAE3 - std_model = VQVAE3() - std_model.load_state_dict(sd) - nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 16, ['quantize_conv_t', 'quantize_conv_b', - 'enc_b.blocks.0', 'enc_t.blocks.0', - 'conv.1', 'conv.3', 'initial_conv', 'final_conv']) - torch.save(nsd, "converted.pth") - - -@register_model -def register_vqvae3_hard_switch(opt_net, opt): - kw = opt_get(opt_net, ['kwargs'], {}) - vq = VQVAE3HardSwitch(**kw) - if distributed.is_initialized() and distributed.get_world_size() 
> 1: - vq = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vq) - return vq - - -def performance_test(): - # For breadth=32: - # Custom_cuda_naive: 15.4 - # Torch_native: 29.2s - # - # For breadth=8 - # Custom_cuda_naive: 9.8 - # Torch_native: 10s - cfg = { - 'mode': 'lambda', - 'breadth': 16, - 'hard_enabled': True, - 'dropout': 0, - } - net = VQVAE3HardSwitch(cfg=cfg).to('cuda').double() - cfg['hard_enabled'] = False - netO = VQVAE3HardSwitch(cfg=cfg).double() - netO.load_state_dict(net.state_dict()) - netO = netO.cpu() - - loss = nn.L1Loss() - opt = torch.optim.Adam(net.parameters(), lr=1e-4) - started = time() - for j in tqdm(range(10)): - inp = torch.rand((4, 3, 64, 64), device='cuda', dtype=torch.double) - res = net(inp)[0] - l = loss(res, inp) - l.backward() - - res2 = netO(inp.cpu())[0] - l = loss(res2, inp.cpu()) - l.backward() - - for p, op in zip(net.parameters(), netO.parameters()): - diff = p.grad.cpu() - op.grad - print(diff.max()) - - opt.step() - net.zero_grad() - print("Elapsed: ", (time()-started)) - - -if __name__ == '__main__': - #v = VQVAE3HardSwitch() - #print(v(torch.randn(1,3,128,128))[0].shape) - #convert_weights("../../../experiments/vqvae_base.pth") - performance_test() diff --git a/codes/models/vqvae/vqvae_3_separated_coupler.py b/codes/models/vqvae/vqvae_3_separated_coupler.py deleted file mode 100644 index 9d58110f..00000000 --- a/codes/models/vqvae/vqvae_3_separated_coupler.py +++ /dev/null @@ -1,179 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - -import torch.distributed as distributed - -from models.vqvae.vqvae import ResBlock, Quantize -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper. 
-class UpsampleConv(nn.Module): - def __init__(self, in_filters, out_filters, kernel_size, padding): - super().__init__() - self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding) - - def forward(self, x): - up = torch.nn.functional.interpolate(x, scale_factor=2) - return self.conv(up) - - -class Encoder(nn.Module): - def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride): - super().__init__() - - if stride == 4: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.LeakyReLU(inplace=True), - nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2), - nn.LeakyReLU(inplace=True), - nn.Conv2d(channel, channel, 3, padding=1), - ] - - elif stride == 2: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.LeakyReLU(inplace=True), - nn.Conv2d(channel // 2, channel, 3, padding=1), - ] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) - - blocks.append(nn.LeakyReLU(inplace=True)) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class Decoder(nn.Module): - def __init__( - self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride - ): - super().__init__() - - blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) - - blocks.append(nn.LeakyReLU(inplace=True)) - - if stride == 4: - blocks.extend( - [ - UpsampleConv(channel, channel // 2, 5, padding=2), - nn.LeakyReLU(inplace=True), - UpsampleConv( - channel // 2, out_channel, 5, padding=2 - ), - ] - ) - - elif stride == 2: - blocks.append( - UpsampleConv(channel, out_channel, 5, padding=2) - ) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class VQVAE3(nn.Module): - def __init__( - self, - in_channel=3, - channel=128, - n_res_block=2, - n_res_channel=32, - codebook_dim=64, - codebook_size=512, - decay=0.99, - ): - super().__init__() - - self.initial_conv = nn.Sequential(*[nn.Conv2d(in_channel, 32, 3, padding=1), - nn.LeakyReLU(inplace=True)]) - self.enc_b = Encoder(32, channel, n_res_block, n_res_channel, stride=4) - self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2) - self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) - self.quantize_t = Quantize(codebook_dim, codebook_size) - self.dec_t = Decoder( - codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2 - ) - self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) - self.quantize_b = Quantize(codebook_dim, codebook_size) - self.upsample_t = UpsampleConv( - codebook_dim, codebook_dim, 5, padding=2 - ) - self.dec = Decoder( - codebook_dim + codebook_dim, - 32, - channel, - n_res_block, - n_res_channel, - stride=4, - ) - self.final_conv = nn.Conv2d(32, in_channel, 3, padding=1) - - def forward(self, input): - quant_t, quant_b, diff, _, _ = self.encode(input) - dec = self.decode(quant_t, quant_b) - - return dec, diff - - def encode(self, input): - fea = self.initial_conv(input) - enc_b = checkpoint(self.enc_b, fea) - enc_t = checkpoint(self.enc_t, enc_b) - - quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) - quant_t, diff_t, id_t = self.quantize_t(quant_t) - quant_t = quant_t.permute(0, 3, 1, 2) - diff_t = diff_t.unsqueeze(0) - - dec_t = checkpoint(self.dec_t, quant_t) - enc_b = torch.cat([dec_t, enc_b], 1) - - quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) - 
quant_b, diff_b, id_b = self.quantize_b(quant_b) - quant_b = quant_b.permute(0, 3, 1, 2) - diff_b = diff_b.unsqueeze(0) - - return quant_t, quant_b, diff_t + diff_b, id_t, id_b - - def decode(self, quant_t, quant_b): - upsample_t = self.upsample_t(quant_t) - quant = torch.cat([upsample_t, quant_b], 1) - dec = checkpoint(self.dec, quant) - dec = checkpoint(self.final_conv, dec) - - return dec - - def decode_code(self, code_t, code_b): - quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute(0, 3, 1, 2) - quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute(0, 3, 1, 2) - - dec = self.decode(quant_t, quant_b) - - return dec - - -@register_model -def register_vqvae3(opt_net, opt): - kw = opt_get(opt_net, ['kwargs'], {}) - return VQVAE3(**kw) - - -if __name__ == '__main__': - v = VQVAE3() - print(v(torch.randn(1,3,128,128))[0].shape) diff --git a/codes/models/vqvae/vqvae_audio_xformer.py b/codes/models/vqvae/vqvae_audio_xformer.py deleted file mode 100644 index 8597bdd5..00000000 --- a/codes/models/vqvae/vqvae_audio_xformer.py +++ /dev/null @@ -1,149 +0,0 @@ -import math -from typing import Optional - -import torch -from torch import nn, Tensor -from torch.nn import functional as F - -import torch.distributed as distributed - -from models.vqvae.vqvae import Quantize -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -class PositionalEncoding(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super(PositionalEncoding, self).__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.register_buffer('pe', pe) - - def forward(self, x): - x = x + self.pe[:, :x.size(1)] - return self.dropout(x) - - -class TransformerEncoderLayer(nn.Module): - def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, - layer_norm_eps=1e-5, device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True, - **factory_kwargs) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) - self.norm1 = nn.BatchNorm1d(d_model) - self.norm2 = nn.BatchNorm1d(d_model) - - self.activation = nn.ReLU() - - def __setstate__(self, state): - if 'activation' not in state: - state['activation'] = F.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: - src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] - src = src + src2 - src = self.norm1(src.permute(0,2,1)).permute(0,2,1) - src2 = self.linear2(self.activation(self.linear1(src))) - src = src + src2 - src = self.norm2(src.permute(0,2,1)).permute(0,2,1) - return src - - -class Encoder(nn.Module): - def __init__(self, in_channel, channel, output_breadth, num_layers=8, compression_factor=8): - super().__init__() - - self.compression_factor = compression_factor - self.pre_conv_stack = 
nn.Sequential(nn.Conv1d(in_channel, channel//4, kernel_size=3, padding=1), - nn.ReLU(), - nn.Conv1d(channel//4, channel//2, kernel_size=3, stride=2, padding=1), - nn.ReLU(), - nn.Conv1d(channel//2, channel//2, kernel_size=3, padding=1), - nn.ReLU(), - nn.Conv1d(channel//2, channel, kernel_size=3, stride=2, padding=1)) - self.norm1 = nn.BatchNorm1d(channel) - self.positional_embeddings = PositionalEncoding(channel, max_len=output_breadth//4) - self.encode = nn.TransformerEncoder(TransformerEncoderLayer(d_model=channel, nhead=4, dim_feedforward=channel*2), num_layers=num_layers) - - def forward(self, input): - x = self.norm1(self.pre_conv_stack(input)).permute(0,2,1) - x = self.positional_embeddings(x) - x = self.encode(x) - return x[:,:input.shape[2]//self.compression_factor,:] - - -class Decoder(nn.Module): - def __init__( - self, in_channel, out_channel, channel, output_breadth, num_layers=6 - ): - super().__init__() - - self.initial_conv = nn.Conv1d(in_channel, channel, kernel_size=1) - self.expand = output_breadth - self.positional_embeddings = PositionalEncoding(channel, max_len=output_breadth) - self.encode = nn.TransformerEncoder(TransformerEncoderLayer(d_model=channel, nhead=4, dim_feedforward=channel*2), num_layers=num_layers) - self.final_conv_stack = nn.Sequential(nn.Conv1d(channel, channel, kernel_size=3, padding=1), - nn.ReLU(), - nn.Conv1d(channel, out_channel, kernel_size=3, padding=1)) - - def forward(self, input): - x = self.initial_conv(input.permute(0,2,1)).permute(0,2,1) - x = nn.functional.pad(x, (0,0,0, self.expand-input.shape[1])) - x = self.positional_embeddings(x) - x = self.encode(x).permute(0,2,1) - return self.final_conv_stack(x) - - -class VQVAE(nn.Module): - def __init__( - self, - data_channels=1, - channel=256, - codebook_dim=256, - codebook_size=512, - breadth=80, - ): - super().__init__() - - self.enc = Encoder(data_channels, channel, breadth) - self.quantize_dense = nn.Linear(channel, codebook_dim) - self.quantize = Quantize(codebook_dim, codebook_size) - self.dec = Decoder(codebook_dim, data_channels, channel, breadth) - - def forward(self, input): - input = input.unsqueeze(1) - quant, diff, _ = self.encode(input) - dec = checkpoint(self.dec, quant) - dec = dec.squeeze(1) - return dec, diff - - def encode(self, input): - enc = checkpoint(self.enc, input) - quant = self.quantize_dense(enc) - quant, diff, id = self.quantize(quant) - diff = diff.unsqueeze(0) - return quant, diff, id - - -@register_model -def register_vqvae_xform_audio(opt_net, opt): - kw = opt_get(opt_net, ['kwargs'], {}) - vq = VQVAE(**kw) - return vq - - -if __name__ == '__main__': - model = VQVAE() - res=model(torch.randn(4,80)) - print(res[0].shape) \ No newline at end of file diff --git a/codes/models/vqvae/vqvae_no_conv_transpose.py b/codes/models/vqvae/vqvae_no_conv_transpose.py deleted file mode 100644 index 2120c937..00000000 --- a/codes/models/vqvae/vqvae_no_conv_transpose.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2018 The Sonnet Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - - -# This is an alternative implementation of VQVAE that uses convolutions with kernels of size 5 and -# a "standard" upsampler rather than ConvTranspose. - - -import torch -from torch import nn -from torch.nn import functional as F - -import torch.distributed as distributed - -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper. -class UpsampleConv(nn.Module): - def __init__(self, in_filters, out_filters, kernel_size, padding): - super().__init__() - self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding) - - def forward(self, x): - up = torch.nn.functional.interpolate(x, scale_factor=2) - return self.conv(up) - - -class Quantize(nn.Module): - def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): - super().__init__() - - self.dim = dim - self.n_embed = n_embed - self.decay = decay - self.eps = eps - - embed = torch.randn(dim, n_embed) - self.register_buffer("embed", embed) - self.register_buffer("cluster_size", torch.zeros(n_embed)) - self.register_buffer("embed_avg", embed.clone()) - - def forward(self, input): - flatten = input.reshape(-1, self.dim) - dist = ( - flatten.pow(2).sum(1, keepdim=True) - - 2 * flatten @ self.embed - + self.embed.pow(2).sum(0, keepdim=True) - ) - _, embed_ind = (-dist).max(1) - embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) - embed_ind = embed_ind.view(*input.shape[:-1]) - quantize = self.embed_code(embed_ind) - - if self.training: - embed_onehot_sum = embed_onehot.sum(0) - embed_sum = flatten.transpose(0, 1) @ embed_onehot - - if distributed.is_initialized() and distributed.get_world_size() > 1: - distributed.all_reduce(embed_onehot_sum) - distributed.all_reduce(embed_sum) - - self.cluster_size.data.mul_(self.decay).add_( - embed_onehot_sum, alpha=1 - self.decay - ) - self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) - n = self.cluster_size.sum() - cluster_size = ( - (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) - self.embed.data.copy_(embed_normalized) - - diff = (quantize.detach() - input).pow(2).mean() - quantize = input + (quantize - input).detach() - - return quantize, diff, embed_ind - - def embed_code(self, embed_id): - return F.embedding(embed_id, self.embed.transpose(0, 1)) - - -class ResBlock(nn.Module): - def __init__(self, in_channel, channel): - super().__init__() - - self.conv = nn.Sequential( - nn.ReLU(inplace=True), - nn.Conv2d(in_channel, channel, 3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(channel, in_channel, 1), - ) - - def forward(self, input): - out = self.conv(input) - out += input - - return out - - -class Encoder(nn.Module): - def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride): - super().__init__() - - if stride == 4: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.ReLU(inplace=True), - nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2), - nn.ReLU(inplace=True), - nn.Conv2d(channel, channel, 3, padding=1), - ] - - elif stride == 2: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.ReLU(inplace=True), - nn.Conv2d(channel // 2, channel, 3, padding=1), - ] - - 
for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) - - blocks.append(nn.ReLU(inplace=True)) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class Decoder(nn.Module): - def __init__( - self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride - ): - super().__init__() - - blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel)) - - blocks.append(nn.ReLU(inplace=True)) - - if stride == 4: - blocks.extend( - [ - UpsampleConv(channel, channel // 2, 5, padding=2), - nn.ReLU(inplace=True), - UpsampleConv( - channel // 2, out_channel, 5, padding=2 - ), - ] - ) - - elif stride == 2: - blocks.append( - UpsampleConv(channel, out_channel, 5, padding=2) - ) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class VQVAE(nn.Module): - def __init__( - self, - in_channel=3, - channel=128, - n_res_block=2, - n_res_channel=32, - codebook_dim=64, - codebook_size=512, - decay=0.99, - ): - super().__init__() - - self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4) - self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2) - self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) - self.quantize_t = Quantize(codebook_dim, codebook_size) - self.dec_t = Decoder( - codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2 - ) - self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) - self.quantize_b = Quantize(codebook_dim, codebook_size*2) - self.upsample_t = UpsampleConv( - codebook_dim, codebook_dim, 5, padding=2 - ) - self.dec = Decoder( - codebook_dim + codebook_dim, - in_channel, - channel, - n_res_block, - n_res_channel, - stride=4, - ) - - def forward(self, input): - quant_t, quant_b, diff, _, _ = self.encode(input) - dec = self.decode(quant_t, quant_b) - - return dec, diff - - def encode(self, input): - enc_b = checkpoint(self.enc_b, input) - enc_t = checkpoint(self.enc_t, enc_b) - - quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) - quant_t, diff_t, id_t = self.quantize_t(quant_t) - quant_t = quant_t.permute(0, 3, 1, 2) - diff_t = diff_t.unsqueeze(0) - - dec_t = checkpoint(self.dec_t, quant_t) - enc_b = torch.cat([dec_t, enc_b], 1) - - quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) - quant_b, diff_b, id_b = self.quantize_b(quant_b) - quant_b = quant_b.permute(0, 3, 1, 2) - diff_b = diff_b.unsqueeze(0) - - return quant_t, quant_b, diff_t + diff_b, id_t, id_b - - def decode(self, quant_t, quant_b): - upsample_t = self.upsample_t(quant_t) - quant = torch.cat([upsample_t, quant_b], 1) - dec = checkpoint(self.dec, quant) - - return dec - - def decode_code(self, code_t, code_b): - quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute(0, 3, 1, 2) - quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute(0, 3, 1, 2) - - dec = self.decode(quant_t, quant_b) - - return dec - - -@register_model -def register_vqvae_normalized(opt_net, opt): - kw = opt_get(opt_net, ['kwargs'], {}) - return VQVAE(**kw) - - -if __name__ == '__main__': - v = VQVAE() - print(v(torch.randn(1,3,128,128))[0].shape) diff --git a/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py b/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py deleted file mode 100644 index 6e3fd071..00000000 --- 
a/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py +++ /dev/null @@ -1,293 +0,0 @@ -import os - -import torch -import torchvision -from torch import nn -from torch.nn import functional as F - -import torch.distributed as distributed - -from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \ - convert_conv_net_state_dict_to_switched_conv -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper. -class UpsampleConv(nn.Module): - def __init__(self, in_filters, out_filters, breadth, kernel_size, padding): - super().__init__() - self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_filters, dropout_rate=0.4) - - def forward(self, x): - up = torch.nn.functional.interpolate(x, scale_factor=2) - return self.conv(up) - - -class Quantize(nn.Module): - def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): - super().__init__() - - self.dim = dim - self.n_embed = n_embed - self.decay = decay - self.eps = eps - - embed = torch.randn(dim, n_embed) - self.register_buffer("embed", embed) - self.register_buffer("cluster_size", torch.zeros(n_embed)) - self.register_buffer("embed_avg", embed.clone()) - - def forward(self, input): - flatten = input.reshape(-1, self.dim) - dist = ( - flatten.pow(2).sum(1, keepdim=True) - - 2 * flatten @ self.embed - + self.embed.pow(2).sum(0, keepdim=True) - ) - _, embed_ind = (-dist).max(1) - embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) - embed_ind = embed_ind.view(*input.shape[:-1]) - quantize = self.embed_code(embed_ind) - - if self.training: - embed_onehot_sum = embed_onehot.sum(0) - embed_sum = flatten.transpose(0, 1) @ embed_onehot - - if distributed.is_initialized() and distributed.get_world_size() > 1: - distributed.all_reduce(embed_onehot_sum) - distributed.all_reduce(embed_sum) - - self.cluster_size.data.mul_(self.decay).add_( - embed_onehot_sum, alpha=1 - self.decay - ) - self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) - n = self.cluster_size.sum() - cluster_size = ( - (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) - self.embed.data.copy_(embed_normalized) - - diff = (quantize.detach() - input).pow(2).mean() - quantize = input + (quantize - input).detach() - - return quantize, diff, embed_ind - - def embed_code(self, embed_id): - return F.embedding(embed_id, self.embed.transpose(0, 1)) - - -class ResBlock(nn.Module): - def __init__(self, in_channel, channel, breadth): - super().__init__() - - self.conv = nn.Sequential( - nn.ReLU(inplace=True), - nn.Conv2d(in_channel, channel, 3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(channel, in_channel, 1), - ) - - def forward(self, input): - out = self.conv(input) - out += input - - return out - - -class Encoder(nn.Module): - def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, breadth): - super().__init__() - - if stride == 4: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.ReLU(inplace=True), - SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4), - nn.ReLU(inplace=True), - SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, 
coupler_mode='standard', coupler_dim_in=channel, dropout_rate=0.4), - ] - - elif stride == 2: - blocks = [ - nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2), - nn.ReLU(inplace=True), - SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4), - ] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel, breadth)) - - blocks.append(nn.ReLU(inplace=True)) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class Decoder(nn.Module): - def __init__( - self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, breadth - ): - super().__init__() - - blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_channel, dropout_rate=0.4)] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel, breadth)) - - blocks.append(nn.ReLU(inplace=True)) - - if stride == 4: - blocks.extend( - [ - UpsampleConv(channel, channel // 2, breadth, 5, padding=2), - nn.ReLU(inplace=True), - UpsampleConv( - channel // 2, out_channel, breadth, 5, padding=2 - ), - ] - ) - - elif stride == 2: - blocks.append( - UpsampleConv(channel, out_channel, breadth, 5, padding=2) - ) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class VQVAE(nn.Module): - def __init__( - self, - in_channel=3, - channel=128, - n_res_block=2, - n_res_channel=32, - codebook_dim=64, - codebook_size=512, - decay=0.99, - breadth=8, - ): - super().__init__() - - self.breadth = breadth - self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth) - self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth) - self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) - self.quantize_t = Quantize(codebook_dim, codebook_size) - self.dec_t = Decoder( - codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth - ) - self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) - self.quantize_b = Quantize(codebook_dim, codebook_size*2) - self.upsample_t = UpsampleConv( - codebook_dim, codebook_dim, breadth, 5, padding=2 - ) - self.dec = Decoder( - codebook_dim + codebook_dim, - in_channel, - channel, - n_res_block, - n_res_channel, - stride=4, - breadth=breadth - ) - - def forward(self, input): - quant_t, quant_b, diff, _, _ = self.encode(input) - dec = self.decode(quant_t, quant_b) - - return dec, diff - - def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'): - from matplotlib import cm - magnitude, indices = torch.topk(attention_out, 3, dim=1) - indices = indices.cpu() - colormap = cm.get_cmap(cmap_discrete_name, attention_size) - img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's - img = img.permute((0, 3, 1, 2)) - torchvision.utils.save_image(img, output_file) - - def visual_dbg(self, step, path): - convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]] - for i, c in enumerate(convs): - self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth) - - def get_debug_values(self, step, __): - switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]), - ('enc_b_blk4', self.enc_b.blocks[4]), - 
('enc_t_blk2', self.enc_t.blocks[2]), - ('dec_t_blk0', self.dec_t.blocks[0]), - ('dec_t_blk-1', self.dec_t.blocks[-1].conv), - ('dec_blk0', self.dec.blocks[0]), - ('dec_blk-1', self.dec.blocks[-1].conv), - ('dec_blk-3', self.dec.blocks[-3].conv)] - logs = {} - for name, swc in switched_convs: - logs[f'{name}_histogram_switch_usage'] = swc.latest_masks - return logs - - def encode(self, input): - enc_b = checkpoint(self.enc_b, input) - enc_t = checkpoint(self.enc_t, enc_b) - - quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) - quant_t, diff_t, id_t = self.quantize_t(quant_t) - quant_t = quant_t.permute(0, 3, 1, 2) - diff_t = diff_t.unsqueeze(0) - - dec_t = checkpoint(self.dec_t, quant_t) - enc_b = torch.cat([dec_t, enc_b], 1) - - quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) - quant_b, diff_b, id_b = self.quantize_b(quant_b) - quant_b = quant_b.permute(0, 3, 1, 2) - diff_b = diff_b.unsqueeze(0) - - return quant_t, quant_b, diff_t + diff_b, id_t, id_b - - def decode(self, quant_t, quant_b): - upsample_t = self.upsample_t(quant_t) - quant = torch.cat([upsample_t, quant_b], 1) - dec = checkpoint(self.dec, quant) - - return dec - - def decode_code(self, code_t, code_b): - quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute(0, 3, 1, 2) - quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute(0, 3, 1, 2) - - dec = self.decode(quant_t, quant_b) - - return dec - - -def convert_weights(weights_file): - sd = torch.load(weights_file) - import models.vqvae.vqvae_no_conv_transpose as stdvq - std_model = stdvq.VQVAE() - std_model.load_state_dict(sd) - nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 8, ['quantize_conv_t', 'quantize_conv_b', - 'enc_b.blocks.0', 'enc_t.blocks.0', - 'conv.1', 'conv.3']) - torch.save(nsd, "converted.pth") - - -@register_model -def register_vqvae_norm_hard_switched_conv_lambda(opt_net, opt): - kw = opt_get(opt_net, ['kwargs'], {}) - return VQVAE(**kw) - - -if __name__ == '__main__': - v = VQVAE(breadth=8).cuda() - print(v(torch.randn(1,3,128,128).cuda())[0].shape) - #convert_weights("../../../experiments/50000_generator.pth") diff --git a/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py deleted file mode 100644 index 6a9f380a..00000000 --- a/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py +++ /dev/null @@ -1,276 +0,0 @@ -import os - -import torch -import torchvision -from torch import nn -from torch.nn import functional as F - -import torch.distributed as distributed - -from models.switched_conv.switched_conv import SwitchedConv, convert_conv_net_state_dict_to_switched_conv -from trainer.networks import register_model -from utils.util import checkpoint, opt_get - - -# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper. 
-class UpsampleConv(nn.Module): - def __init__(self, in_filters, out_filters, breadth, kernel_size, padding): - super().__init__() - self.conv = SwitchedConv(in_filters, out_filters, kernel_size, breadth, padding=padding, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters) - - def forward(self, x): - up = torch.nn.functional.interpolate(x, scale_factor=2) - return self.conv(up) - - -class Quantize(nn.Module): - def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): - super().__init__() - - self.dim = dim - self.n_embed = n_embed - self.decay = decay - self.eps = eps - - embed = torch.randn(dim, n_embed) - self.register_buffer("embed", embed) - self.register_buffer("cluster_size", torch.zeros(n_embed)) - self.register_buffer("embed_avg", embed.clone()) - - def forward(self, input): - flatten = input.reshape(-1, self.dim) - dist = ( - flatten.pow(2).sum(1, keepdim=True) - - 2 * flatten @ self.embed - + self.embed.pow(2).sum(0, keepdim=True) - ) - _, embed_ind = (-dist).max(1) - embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) - embed_ind = embed_ind.view(*input.shape[:-1]) - quantize = self.embed_code(embed_ind) - - if self.training: - embed_onehot_sum = embed_onehot.sum(0) - embed_sum = flatten.transpose(0, 1) @ embed_onehot - - if distributed.is_initialized() and distributed.get_world_size() > 1: - distributed.all_reduce(embed_onehot_sum) - distributed.all_reduce(embed_sum) - - self.cluster_size.data.mul_(self.decay).add_( - embed_onehot_sum, alpha=1 - self.decay - ) - self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) - n = self.cluster_size.sum() - cluster_size = ( - (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) - self.embed.data.copy_(embed_normalized) - - diff = (quantize.detach() - input).pow(2).mean() - quantize = input + (quantize - input).detach() - - return quantize, diff, embed_ind - - def embed_code(self, embed_id): - return F.embedding(embed_id, self.embed.transpose(0, 1)) - - -class ResBlock(nn.Module): - def __init__(self, in_channel, channel, breadth): - super().__init__() - - self.conv = nn.Sequential( - nn.ReLU(inplace=True), - SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel), - nn.ReLU(inplace=True), - SwitchedConv(channel, in_channel, 1, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel), - ) - - def forward(self, input): - out = self.conv(input) - out += input - - return out - - -class Encoder(nn.Module): - def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, breadth): - super().__init__() - - if stride == 4: - blocks = [ - SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel), - nn.ReLU(inplace=True), - SwitchedConv(channel // 2, channel, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2), - nn.ReLU(inplace=True), - SwitchedConv(channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel), - ] - - elif stride == 2: - blocks = [ - SwitchedConv(in_channel, channel // 2, 5, breadth, stride=2, padding=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel), - nn.ReLU(inplace=True), - SwitchedConv(channel // 2, channel, 3, breadth, padding=1, include_coupler=True, 
coupler_mode='lambda', coupler_dim_in=channel // 2), - ] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel, breadth)) - - blocks.append(nn.ReLU(inplace=True)) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class Decoder(nn.Module): - def __init__( - self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, breadth - ): - super().__init__() - - blocks = [SwitchedConv(in_channel, channel, 3, breadth, padding=1, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel)] - - for i in range(n_res_block): - blocks.append(ResBlock(channel, n_res_channel, breadth)) - - blocks.append(nn.ReLU(inplace=True)) - - if stride == 4: - blocks.extend( - [ - UpsampleConv(channel, channel // 2, breadth, 5, padding=2), - nn.ReLU(inplace=True), - UpsampleConv( - channel // 2, out_channel, breadth, 5, padding=2 - ), - ] - ) - - elif stride == 2: - blocks.append( - UpsampleConv(channel, out_channel, breadth, 5, padding=2) - ) - - self.blocks = nn.Sequential(*blocks) - - def forward(self, input): - return self.blocks(input) - - -class VQVAE(nn.Module): - def __init__( - self, - in_channel=3, - channel=128, - n_res_block=2, - n_res_channel=32, - codebook_dim=64, - codebook_size=512, - decay=0.99, - breadth=4, - ): - super().__init__() - - self.breadth = breadth - self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth) - self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth) - self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) - self.quantize_t = Quantize(codebook_dim, codebook_size) - self.dec_t = Decoder( - codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth - ) - self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) - self.quantize_b = Quantize(codebook_dim, codebook_size*2) - self.upsample_t = UpsampleConv( - codebook_dim, codebook_dim, breadth, 5, padding=2 - ) - self.dec = Decoder( - codebook_dim + codebook_dim, - in_channel, - channel, - n_res_block, - n_res_channel, - stride=4, - breadth=breadth - ) - - def forward(self, input): - quant_t, quant_b, diff, _, _ = self.encode(input) - dec = self.decode(quant_t, quant_b) - - return dec, diff - - def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'): - from matplotlib import cm - magnitude, indices = torch.topk(attention_out, 3, dim=1) - indices = indices.cpu() - colormap = cm.get_cmap(cmap_discrete_name, attention_size) - img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's - img = img.permute((0, 3, 1, 2)) - torchvision.utils.save_image(img, output_file) - - def visual_dbg(self, step, path): - convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]] - for i, c in enumerate(convs): - self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth) - - def encode(self, input): - enc_b = checkpoint(self.enc_b, input) - enc_t = checkpoint(self.enc_t, enc_b) - - quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1) - quant_t, diff_t, id_t = self.quantize_t(quant_t) - quant_t = quant_t.permute(0, 3, 1, 2) - diff_t = diff_t.unsqueeze(0) - - dec_t = checkpoint(self.dec_t, quant_t) - enc_b = torch.cat([dec_t, enc_b], 1) - - quant_b = checkpoint(self.quantize_conv_b, enc_b).permute(0, 2, 3, 1) - quant_b, 
diff_b, id_b = self.quantize_b(quant_b) - quant_b = quant_b.permute(0, 3, 1, 2) - diff_b = diff_b.unsqueeze(0) - - return quant_t, quant_b, diff_t + diff_b, id_t, id_b - - def decode(self, quant_t, quant_b): - upsample_t = self.upsample_t(quant_t) - quant = torch.cat([upsample_t, quant_b], 1) - dec = checkpoint(self.dec, quant) - - return dec - - def decode_code(self, code_t, code_b): - quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute(0, 3, 1, 2) - quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute(0, 3, 1, 2) - - dec = self.decode(quant_t, quant_b) - - return dec - - -def convert_weights(weights_file): - sd = torch.load(weights_file) - import models.vqvae.vqvae_no_conv_transpose as stdvq - std_model = stdvq.VQVAE() - std_model.load_state_dict(sd) - nsd = convert_conv_net_state_dict_to_switched_conv(std_model, 4, ['quantize_conv_t', 'quantize_conv_b']) - torch.save(nsd, "converted.pth") - - -@register_model -def register_vqvae_norm_switched_conv_lambda(opt_net, opt): - kw = opt_get(opt_net, ['kwargs'], {}) - return VQVAE(**kw) - - -if __name__ == '__main__': - #v = VQVAE() - #print(v(torch.randn(1,3,128,128))[0].shape) - convert_weights("../../../experiments/4000_generator.pth") diff --git a/codes/scripts/diffusion/diffusion_noise_surfer.py b/codes/scripts/diffusion/diffusion_noise_surfer.py index 68e4ea52..b934369a 100644 --- a/codes/scripts/diffusion/diffusion_noise_surfer.py +++ b/codes/scripts/diffusion/diffusion_noise_surfer.py @@ -55,7 +55,7 @@ if __name__ == "__main__": torch.backends.cudnn.benchmark = True want_metrics = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-17.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-20.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt diff --git a/codes/train.py b/codes/train.py index 1cc43b7a..26be2cff 100644 --- a/codes/train.py +++ b/codes/train.py @@ -284,7 +284,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_vocoder_clips.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_dvae_audio_clips_with_quantizer_compression.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()