Gumbel quantizer
This commit is contained in:
parent
c5297ccec6
commit
3e64e847c2
|
@ -7,6 +7,7 @@ import torch.nn as nn
|
|||
|
||||
from models.gpt_voice.lucidrains_dvae import eval_decorator
|
||||
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
|
||||
from models.vqvae.gumbel_quantizer import GumbelQuantizer
|
||||
from models.vqvae.vqvae import Quantize
|
||||
from trainer.networks import register_model
|
||||
from utils.util import get_mask_from_lengths
|
||||
|
@ -106,7 +107,8 @@ class DiffusionDVAE(nn.Module):
|
|||
self.scale_steps = scale_steps
|
||||
|
||||
self.encoder = DiscreteEncoder(spectrogram_channels, model_channels*4, quantize_dim, dropout, scale_steps)
|
||||
self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
|
||||
#self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
|
||||
self.quantizer = GumbelQuantizer(quantize_dim, quantize_dim, num_discrete_codes)
|
||||
# For recording codebook usage.
|
||||
self.codes = torch.zeros((131072,), dtype=torch.long)
|
||||
self.code_ind = 0
|
||||
|
@ -375,7 +377,7 @@ class DiffusionDVAE(nn.Module):
|
|||
|
||||
# Test for ~4 second audio clip at 22050Hz
|
||||
if __name__ == '__main__':
|
||||
spec = torch.randn(4, 80, 161)
|
||||
spec = torch.randn(4, 80, 160)
|
||||
ts = torch.LongTensor([432, 234, 100, 555])
|
||||
model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
|
||||
channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
|
||||
|
|
24
codes/models/vqvae/gumbel_quantizer.py
Normal file
24
codes/models/vqvae/gumbel_quantizer.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum
|
||||
|
||||
|
||||
class GumbelQuantizer(nn.Module):
|
||||
def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False, temperature=.9):
|
||||
super().__init__()
|
||||
self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
|
||||
self.codebook = nn.Embedding(num_tokens, codebook_dim)
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temperature
|
||||
|
||||
def embed_code(self, codes):
|
||||
return self.codebook(codes)
|
||||
|
||||
def forward(self, h):
|
||||
h = h.permute(0,2,1)
|
||||
logits = self.to_logits(h)
|
||||
logits = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=self.straight_through)
|
||||
codes = logits.argmax(dim=1).flatten(1)
|
||||
sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight)
|
||||
return sampled.permute(0,2,1), 0, codes
|
Loading…
Reference in New Issue
Block a user