From 3e64e847c28ba0318c265f01a6b3ee6829077158 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 23 Sep 2021 23:32:03 -0600 Subject: [PATCH] Gumbel quantizer --- codes/models/diffusion/diffusion_dvae.py | 6 ++++-- codes/models/vqvae/gumbel_quantizer.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 codes/models/vqvae/gumbel_quantizer.py diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py index a8d4c50b..1e90a47f 100644 --- a/codes/models/diffusion/diffusion_dvae.py +++ b/codes/models/diffusion/diffusion_dvae.py @@ -7,6 +7,7 @@ import torch.nn as nn from models.gpt_voice.lucidrains_dvae import eval_decorator from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner +from models.vqvae.gumbel_quantizer import GumbelQuantizer from models.vqvae.vqvae import Quantize from trainer.networks import register_model from utils.util import get_mask_from_lengths @@ -106,7 +107,8 @@ class DiffusionDVAE(nn.Module): self.scale_steps = scale_steps self.encoder = DiscreteEncoder(spectrogram_channels, model_channels*4, quantize_dim, dropout, scale_steps) - self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True) + #self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True) + self.quantizer = GumbelQuantizer(quantize_dim, quantize_dim, num_discrete_codes) # For recording codebook usage. self.codes = torch.zeros((131072,), dtype=torch.long) self.code_ind = 0 @@ -375,7 +377,7 @@ class DiffusionDVAE(nn.Module): # Test for ~4 second audio clip at 22050Hz if __name__ == '__main__': - spec = torch.randn(4, 80, 161) + spec = torch.randn(4, 80, 160) ts = torch.LongTensor([432, 234, 100, 555]) model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2], channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False) diff --git a/codes/models/vqvae/gumbel_quantizer.py b/codes/models/vqvae/gumbel_quantizer.py new file mode 100644 index 00000000..21937cbf --- /dev/null +++ b/codes/models/vqvae/gumbel_quantizer.py @@ -0,0 +1,24 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + + +class GumbelQuantizer(nn.Module): + def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False, temperature=.9): + super().__init__() + self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1) + self.codebook = nn.Embedding(num_tokens, codebook_dim) + self.straight_through = straight_through + self.temperature = temperature + + def embed_code(self, codes): + return self.codebook(codes) + + def forward(self, h): + h = h.permute(0,2,1) + logits = self.to_logits(h) + logits = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=self.straight_through) + codes = logits.argmax(dim=1).flatten(1) + sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight) + return sampled.permute(0,2,1), 0, codes \ No newline at end of file