Gumbel quantizer

James Betker 2021-09-23 23:32:03 -06:00
parent c5297ccec6
commit 3e64e847c2
2 changed files with 28 additions and 2 deletions

@@ -7,6 +7,7 @@
 import torch.nn as nn
 from models.gpt_voice.lucidrains_dvae import eval_decorator
 from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
+from models.vqvae.gumbel_quantizer import GumbelQuantizer
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
 from utils.util import get_mask_from_lengths
@@ -106,7 +107,8 @@ class DiffusionDVAE(nn.Module):
         self.scale_steps = scale_steps
         self.encoder = DiscreteEncoder(spectrogram_channels, model_channels*4, quantize_dim, dropout, scale_steps)
-        self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
+        #self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
+        self.quantizer = GumbelQuantizer(quantize_dim, quantize_dim, num_discrete_codes)
         # For recording codebook usage.
         self.codes = torch.zeros((131072,), dtype=torch.long)
         self.code_ind = 0
@@ -375,7 +377,7 @@
 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
-    spec = torch.randn(4, 80, 161)
+    spec = torch.randn(4, 80, 160)
     ts = torch.LongTensor([432, 234, 100, 555])
     model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
                           channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
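
Note on the swap above (not part of the diff): GumbelQuantizer works as a drop-in replacement because it preserves the (quantized, diff, codes) return triple that Quantize produces; since Gumbel-softmax quantization needs no commitment loss, diff is simply the constant 0. A minimal sketch of the assumed call-site contract, with hypothetical tensor names (the actual encoder/quantizer wiring is not shown in this diff):

    h = encoder_output                          # (batch, length, quantize_dim), channels last
    quantized, diff, codes = self.quantizer(h)  # quantized: (batch, length, quantize_dim)
                                                # diff: 0 (no commitment loss)
                                                # codes: (batch, length) hard code indices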

models/vqvae/gumbel_quantizer.py (new file)
@@ -0,0 +1,24 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import einsum
+
+
+class GumbelQuantizer(nn.Module):
+    def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False, temperature=.9):
+        super().__init__()
+        self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
+        self.codebook = nn.Embedding(num_tokens, codebook_dim)
+        self.straight_through = straight_through
+        self.temperature = temperature
+
+    def embed_code(self, codes):
+        return self.codebook(codes)
+
+    def forward(self, h):
+        h = h.permute(0,2,1)
+        logits = self.to_logits(h)
+        logits = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=self.straight_through)
+        codes = logits.argmax(dim=1).flatten(1)
+        sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
+        return sampled.permute(0,2,1), 0, codes
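
For orientation (not part of the commit): the new module projects channels-last features to per-position logits over the codebook, samples a (soft or straight-through) one-hot via Gumbel-softmax, and mixes codebook vectors with that sample via einsum. A minimal smoke test follows; the module path comes from the import in the first file, and the dimension values are hypothetical:

    import torch
    from models.vqvae.gumbel_quantizer import GumbelQuantizer

    quantizer = GumbelQuantizer(inp_dim=1024, codebook_dim=1024, num_tokens=512)
    h = torch.randn(4, 161, 1024)             # (batch, length, inp_dim)
    sampled, diff, codes = quantizer(h)
    assert sampled.shape == (4, 161, 1024)    # soft mixture of codebook vectors, channels last
    assert codes.shape == (4, 161)            # hard code index per timestep (for usage tracking)
    assert diff == 0                          # placeholder; Gumbel-softmax needs no commitment loss

With straight_through=False (the default), sampled is a differentiable convex combination of codebook entries; setting straight_through=True makes the forward pass hard one-hot while gradients still flow through the soft sample.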