From 0dee15f8757834efd272303e3f46dec15985808e Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 20 Oct 2021 21:19:38 -0600 Subject: [PATCH] base DVAE & vector_quantizer --- .../unet_diffusion_vocoder_with_ref.py | 4 +- codes/models/vqvae/dvae.py | 241 +++++++++++++++++ codes/models/vqvae/vector_quantizer.py | 245 ++++++++++++++++++ 3 files changed, 488 insertions(+), 2 deletions(-) create mode 100644 codes/models/vqvae/dvae.py create mode 100644 codes/models/vqvae/vector_quantizer.py diff --git a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py index 051d8d5d..dd7800a9 100644 --- a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py +++ b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py @@ -285,7 +285,7 @@ class DiffusionVocoderWithRef(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps, discrete_spectrogram, conditioning_inputs=None, num_conditioning_signals=None): + def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None): """ Apply the model to an input batch. @@ -311,7 +311,7 @@ class DiffusionVocoderWithRef(nn.Module): h = x.type(self.dtype) for k, module in enumerate(self.input_blocks): if isinstance(module, DiscreteSpectrogramConditioningBlock): - h = module(h, discrete_spectrogram) + h = module(h, spectrogram) else: h = module(h, emb) hs.append(h) diff --git a/codes/models/vqvae/dvae.py b/codes/models/vqvae/dvae.py new file mode 100644 index 00000000..6628bc59 --- /dev/null +++ b/codes/models/vqvae/dvae.py @@ -0,0 +1,241 @@ +import functools +import math +from math import sqrt + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import einsum + +from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss +from models.vqvae.vector_quantizer import VectorQuantize +from models.vqvae.vqvae import Quantize +from trainer.networks import register_model +from utils.util import opt_get + + +def eval_decorator(fn): + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + return inner + + +class ResBlock(nn.Module): + def __init__(self, chan, conv, activation): + super().__init__() + self.net = nn.Sequential( + conv(chan, chan, 3, padding = 1), + activation(), + conv(chan, chan, 3, padding = 1), + activation(), + conv(chan, chan, 1) + ) + + def forward(self, x): + return self.net(x) + x + + +class UpsampledConv(nn.Module): + def __init__(self, conv, *args, **kwargs): + super().__init__() + assert 'stride' in kwargs.keys() + self.stride = kwargs['stride'] + del kwargs['stride'] + self.conv = conv(*args, **kwargs) + + def forward(self, x): + up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest') + return self.conv(up) + + +class DiscreteVAE(nn.Module): + def __init__( + self, + positional_dims=2, + num_tokens = 512, + codebook_dim = 512, + num_layers = 3, + num_resnet_blocks = 0, + hidden_dim = 64, + channels = 3, + stride = 2, + kernel_size = 3, + activation = 'relu', + straight_through = False, + record_codes = False, + discretization_loss_averaging_steps = 100, + quantizer_use_cosine_sim=True, + quantizer_codebook_misses_to_expiration=40, + quantizer_codebook_embedding_compression=None, + ): + super().__init__() + assert num_layers >= 1, 'number of layers must be greater than or equal to 1' + has_resblocks = num_resnet_blocks > 0 + + self.num_tokens = num_tokens + self.num_layers = num_layers + self.straight_through = straight_through + self.positional_dims = positional_dims + self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps) + + assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now. + if positional_dims == 2: + conv = nn.Conv2d + conv_transpose = functools.partial(UpsampledConv, conv) + else: + conv = nn.Conv1d + conv_transpose = functools.partial(UpsampledConv, conv) + + if activation == 'relu': + act = nn.ReLU + elif activation == 'silu': + act = nn.SiLU + else: + assert NotImplementedError() + + + enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)] + dec_chans = list(reversed(enc_chans)) + + enc_chans = [channels, *enc_chans] + + dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] + dec_chans = [dec_init_chan, *dec_chans] + + enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) + + enc_layers = [] + dec_layers = [] + + pad = (kernel_size - 1) // 2 + for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): + enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act())) + dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act())) + + for _ in range(num_resnet_blocks): + dec_layers.insert(0, ResBlock(dec_chans[1], conv, act)) + enc_layers.append(ResBlock(enc_chans[-1], conv, act)) + + if num_resnet_blocks > 0: + dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1)) + + enc_layers.append(conv(enc_chans[-1], codebook_dim, 1)) + dec_layers.append(conv(dec_chans[-1], channels, 1)) + + self.encoder = nn.Sequential(*enc_layers) + self.quantizer = VectorQuantize(codebook_dim, num_tokens, codebook_dim=quantizer_codebook_embedding_compression, + use_cosine_sim=quantizer_use_cosine_sim, + max_codebook_misses_before_expiry=quantizer_codebook_misses_to_expiration) + self.decoder = nn.Sequential(*dec_layers) + + self.loss_fn = F.mse_loss + + self.record_codes = record_codes + if record_codes: + self.codes = torch.zeros((1228800,), dtype=torch.long) + self.code_ind = 0 + self.internal_step = 0 + + def get_debug_values(self, step, __): + if self.record_codes: + # Report annealing schedule + return {'histogram_codes': self.codes} + else: + return {} + + @torch.no_grad() + @eval_decorator + def get_codebook_indices(self, images): + logits = self.encoder(images).permute((0,2,3,1) if len(images.shape) == 4 else (0,2,1)) + sampled, codes, commitment_loss = self.quantizer(logits) + return codes + + def decode( + self, + img_seq + ): + self.log_codes(img_seq) + image_embeds = self.quantizer.decode(img_seq) + b, n, d = image_embeds.shape + + kwargs = {} + if self.positional_dims == 1: + arrange = 'b n d -> b d n' + else: + h = w = int(sqrt(n)) + arrange = 'b (h w) d -> b d h w' + kwargs = {'h': h, 'w': w} + image_embeds = rearrange(image_embeds, arrange, **kwargs) + images = [image_embeds] + for layer in self.decoder: + images.append(layer(images[-1])) + return images[-1], images[-2] + + def infer(self, img): + logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + sampled, codes, commitment_loss = self.quantizer(logits) + return self.decode(codes) + + # Note: This module is not meant to be run in forward() except while training. It has special logic which performs + # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially + # more lossy (but useful for determining network performance). + def forward( + self, + img + ): + logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + sampled, codes, commitment_loss = self.quantizer(logits) + sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1)) + + if self.training: + out = sampled + for d in self.decoder: + out = d(out) + else: + # This is non-differentiable, but gives a better idea of how the network is actually performing. + out, _ = self.decode(codes) + + # reconstruction loss + recon_loss = self.loss_fn(img, out, reduction='none') + + # This is so we can debug the distribution of codes being learned. + self.log_codes(codes) + + return recon_loss, commitment_loss, out + + def log_codes(self, codes): + # This is so we can debug the distribution of codes being learned. + if self.record_codes and self.internal_step % 50 == 0: + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i:i+l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.internal_step += 1 + + +@register_model +def register_dvae(opt_net, opt): + return DiscreteVAE(**opt_get(opt_net, ['kwargs'], {})) + + +if __name__ == '__main__': + #v = DiscreteVAE() + #o=v(torch.randn(1,3,256,256)) + #print(o.shape) + v = DiscreteVAE(channels=80, positional_dims=1, num_tokens=4096, codebook_dim=1024, + hidden_dim=512, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, + quantizer_codebook_embedding_compression=64) + #v.eval() + loss, commitment, out = v(torch.randn(1,80,256)) + print(out.shape) + codes = v.get_codebook_indices(torch.randn(1,80,256)) + back, back_emb = v.decode(codes) + print(back.shape) diff --git a/codes/models/vqvae/vector_quantizer.py b/codes/models/vqvae/vector_quantizer.py new file mode 100644 index 00000000..96015b02 --- /dev/null +++ b/codes/models/vqvae/vector_quantizer.py @@ -0,0 +1,245 @@ +import torch +from torch import nn, einsum +import torch.nn.functional as F +from einops import rearrange, repeat + +from models.arch_util import l2norm, sample_vectors, default, ema_inplace + + +def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim = -1) + + buckets = dists.max(dim = -1).indices + bins = torch.bincount(buckets, minlength = num_clusters) + zero_mask = bins == 0 + bins = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples) + new_means = new_means / bins[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means + +# distance types + +class EuclideanCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + kmeans_init = False, + kmeans_iters = 10, + decay = 0.8, + eps = 1e-5 + ): + super().__init__() + self.decay = decay + init_fn = torch.randn if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + self.kmeans_iters = kmeans_iters + self.eps = eps + + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.register_buffer('cluster_size', torch.zeros(codebook_size)) + self.register_buffer('embed', embed) + self.register_buffer('embed_avg', embed.clone()) + + def init_embed_(self, data): + embed = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.initted.data.copy_(torch.Tensor([True])) + + def replace(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def forward(self, x): + shape, dtype = x.shape, x.dtype + flatten = rearrange(x, '... d -> (...) d') + embed = self.embed.t() + + if not self.initted: + self.init_embed_(flatten) + + dist = -( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + + embed_ind = dist.max(dim = -1).indices + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype) + embed_ind = embed_ind.view(*shape[:-1]) + quantize = F.embedding(embed_ind, self.embed) + + if self.training: + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = flatten.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + +class CosineSimCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + kmeans_init = False, + kmeans_iters = 10, + decay = 0.8, + eps = 1e-5 + ): + super().__init__() + self.decay = decay + + if not kmeans_init: + embed = l2norm(torch.randn(codebook_size, dim)) + else: + embed = torch.zeros(codebook_size, dim) + + self.codebook_size = codebook_size + self.kmeans_iters = kmeans_iters + self.eps = eps + + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.register_buffer('embed', embed) + + def init_embed_(self, data): + embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True) + self.embed.data.copy_(embed) + self.initted.data.copy_(torch.Tensor([True])) + + def replace(self, samples, mask): + samples = l2norm(samples) + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def forward(self, x): + shape, dtype = x.shape, x.dtype + flatten = rearrange(x, '... d -> (...) d') + flatten = l2norm(flatten) + + if not self.initted: + self.init_embed_(flatten) + + embed = l2norm(self.embed) + dist = flatten @ embed.t() + embed_ind = dist.max(dim = -1).indices + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = embed_ind.view(*shape[:-1]) + + quantize = F.embedding(embed_ind, self.embed) + + if self.training: + bins = embed_onehot.sum(0) + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) + + embed_sum = flatten.t() @ embed_onehot + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = l2norm(embed_normalized) + embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized) + ema_inplace(self.embed, embed_normalized, self.decay) + + return quantize, embed_ind + +# main class + +class VectorQuantize(nn.Module): + def __init__( + self, + dim, + codebook_size, + n_embed = None, + codebook_dim = None, + decay = 0.8, + eps = 1e-5, + kmeans_init = False, + kmeans_iters = 10, + use_cosine_sim = False, + max_codebook_misses_before_expiry = 0 + ): + super().__init__() + n_embed = default(n_embed, codebook_size) + + codebook_dim = default(codebook_dim, dim) + requires_projection = codebook_dim != dim + self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() + + self.eps = eps + + klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook + + self._codebook = klass( + dim = codebook_dim, + codebook_size = n_embed, + kmeans_init = kmeans_init, + kmeans_iters = kmeans_iters, + decay = decay, + eps = eps + ) + + self.codebook_size = codebook_size + self.max_codebook_misses_before_expiry = max_codebook_misses_before_expiry + + if max_codebook_misses_before_expiry > 0: + codebook_misses = torch.zeros(codebook_size) + self.register_buffer('codebook_misses', codebook_misses) + + @property + def codebook(self): + return self._codebook.codebook + + def decode(self, codes): + unembed = F.embedding(codes, self._codebook.embed) + return self.project_out(unembed) + + def expire_codes_(self, embed_ind, batch_samples): + if self.max_codebook_misses_before_expiry == 0: + return + + embed_ind = rearrange(embed_ind, '... -> (...)') + misses = torch.bincount(embed_ind, minlength = self.codebook_size) == 0 + self.codebook_misses += misses + + expired_codes = self.codebook_misses >= self.max_codebook_misses_before_expiry + if not torch.any(expired_codes): + return + + self.codebook_misses.masked_fill_(expired_codes, 0) + batch_samples = rearrange(batch_samples, '... d -> (...) d') + self._codebook.replace(batch_samples, mask = expired_codes) + + def forward(self, x): + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + commit_loss = F.mse_loss(quantize.detach(), x) + + if self.training: + quantize = x + (quantize - x).detach() + self.expire_codes_(embed_ind, x) + + quantize = self.project_out(quantize) + return quantize, embed_ind, commit_loss