From 8f40108f5b9df810db35ed8a1979133310ba3c12 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 10 Jun 2022 14:51:59 -0600
Subject: [PATCH] lets try a different tact

---
 .../audio/music/transformer_diffusion8.py    | 28 ++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py
index 4d782f89..73a8b3e3 100644
--- a/codes/models/audio/music/transformer_diffusion8.py
+++ b/codes/models/audio/music/transformer_diffusion8.py
@@ -196,16 +196,18 @@ class TransformerDiffusion(nn.Module):
 
 
 
 class TransformerDiffusionWithQuantizer(nn.Module):
-    def __init__(self, freeze_quantizer_until=20000, **kwargs):
+    def __init__(self, train_quantizer_reconstruction_until=-1, freeze_quantizer_until=10000, **kwargs):
         super().__init__()
-        self.internal_step = 0
         self.freeze_quantizer_until = freeze_quantizer_until
+        self.train_quantizer_reconstruction_until = train_quantizer_reconstruction_until
         self.diff = TransformerDiffusion(**kwargs)
         self.quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256,
                                          codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5)
         self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature
-        del self.quantizer.up
+        if train_quantizer_reconstruction_until == -1:
+            # We won't be using the upsampler, so delete it.
+            del self.quantizer.up
 
     def update_for_step(self, step, *args):
         self.internal_step = step
@@ -216,13 +218,24 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         )
 
     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
-        quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
+        diff_disabled = self.internal_step < self.train_quantizer_reconstruction_until
+        if diff_disabled:
+            mse, diversity_loss = self.quantizer(truth_mel)
+
+            # Use the diff parameters so DDP doesn't give us grief.
+            unused = 0
+            for p in self.diff.parameters():
+                unused = unused + p.mean() * 0
+            mse = mse + unused
+            return x, diversity_loss, mse
+
+        quant_grad_enabled = self.internal_step >= self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
             proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True)
             proj = proj.permute(0,2,1)
 
-        # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
         if not quant_grad_enabled:
+            # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
             unused = 0
             for p in self.quantizer.parameters():
                 unused = unused + p.mean() * 0
@@ -232,7 +245,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
         if disable_diversity:
             return diff
-        return diff, diversity_loss
+        return diff, diversity_loss, None
 
     def get_debug_values(self, step, __):
         if self.quantizer.total_codes > 0:
@@ -317,7 +330,8 @@ def test_quant_model():
     clip = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
     model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024,
-                                               input_vec_dim=1024, num_layers=16, prenet_layers=6)
+                                               input_vec_dim=1024, num_layers=16, prenet_layers=6,
+                                               train_quantizer_reconstruction_until=1000)
     model.get_grad_norm_parameter_groups()
     print_network(model)
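
The workaround in the `forward()` hunk above (touching every parameter of the inactive sub-module so DistributedDataParallel does not complain about unused parameters) is easy to miss in diff form. Below is a minimal, standalone sketch of the same pattern, not part of the patch; the names `TwoStageModel`, `switch_step`, `stage_a`, and `stage_b` are hypothetical and only mirror the quantizer/diffusion split used here.

```python
# Illustrative sketch (not part of the patch): the "zero-cost parameter touch"
# pattern used above to keep DistributedDataParallel happy when one branch of
# the model is skipped for the current step. DDP expects every parameter that
# requires grad to take part in the backward pass; multiplying each unused
# parameter's mean by zero puts it in the graph without changing the loss.
import torch
import torch.nn as nn


class TwoStageModel(nn.Module):
    def __init__(self, switch_step=1000):
        super().__init__()
        self.stage_a = nn.Linear(16, 16)   # stands in for the quantizer
        self.stage_b = nn.Linear(16, 16)   # stands in for the diffusion model
        self.switch_step = switch_step
        self.internal_step = 0

    def forward(self, x):
        if self.internal_step < self.switch_step:
            # Early phase: only stage_a is really trained.
            loss = self.stage_a(x).mean()
            # Touch stage_b's parameters so DDP still sees gradients for them.
            unused = sum(p.mean() * 0 for p in self.stage_b.parameters())
            return loss + unused
        # Later phase: only stage_b is really trained.
        loss = self.stage_b(x).mean()
        unused = sum(p.mean() * 0 for p in self.stage_a.parameters())
        return loss + unused


if __name__ == "__main__":
    model = TwoStageModel()
    out = model(torch.randn(4, 16))
    out.backward()
    # Every parameter now has a (possibly zero) gradient, which is what DDP requires.
    assert all(p.grad is not None for p in model.parameters())
```

The alternative would be constructing DDP with `find_unused_parameters=True`, which adds per-step overhead; the zero-multiply trick keeps the default, faster reducer path while the training schedule switches branches.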