diff --git a/codes/models/audio/music/music_quantizer2.py b/codes/models/audio/music/music_quantizer2.py
index d7de5ad5..8fa73c65 100644
--- a/codes/models/audio/music/music_quantizer2.py
+++ b/codes/models/audio/music/music_quantizer2.py
@@ -45,7 +45,7 @@ class Upsample(nn.Module):
 
 
 class ResBlock(nn.Module):
-    def __init__(self, chan, checkpoint=True):
+    def __init__(self, chan):
         super().__init__()
         self.net = nn.Sequential(
             nn.Conv1d(chan, chan, 3, padding = 1),
@@ -56,13 +56,9 @@ class ResBlock(nn.Module):
             nn.SiLU(),
             zero_module(nn.Conv1d(chan, chan, 3, padding = 1)),
         )
-        self.checkpoint = checkpoint
 
     def forward(self, x):
-        if self.checkpoint:
-            return checkpoint(self._forward, x) + x
-        else:
-            return self._forward(x) + x
+        return checkpoint(self._forward, x) + x
 
     def _forward(self, x):
         return self.net(x)
@@ -169,7 +165,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 class MusicQuantizer2(nn.Module):
     def __init__(self, inp_channels=256, inner_dim=1024, codevector_dim=1024, down_steps=2,
                  max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
-                 codebook_size=16, codebook_groups=4, checkpoint=True,
+                 codebook_size=16, codebook_groups=4,
                  # Downsample args:
                  expressive_downsamples=False):
         super().__init__()
@@ -195,14 +191,14 @@ class MusicQuantizer2(nn.Module):
         self.up = nn.Sequential(*[Upsample(inner_dim[i], inner_dim[i+1]) for i in range(len(inner_dim)-1)] +
                                 [nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1)])
 
-        self.encoder = nn.Sequential(ResBlock(inner_dim[0], checkpoint=checkpoint),
-                                     ResBlock(inner_dim[0], checkpoint=checkpoint),
-                                     ResBlock(inner_dim[0], checkpoint=checkpoint))
+        self.encoder = nn.Sequential(ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]))
         self.enc_norm = nn.LayerNorm(inner_dim[0], eps=1e-5)
         self.decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim[0], kernel_size=3, padding=1),
-                                     ResBlock(inner_dim[0], checkpoint=checkpoint),
-                                     ResBlock(inner_dim[0], checkpoint=checkpoint),
-                                     ResBlock(inner_dim[0], checkpoint=checkpoint))
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]))
 
         self.codes = torch.zeros((3000000,), dtype=torch.long)
         self.internal_step = 0
@@ -228,18 +224,14 @@ class MusicQuantizer2(nn.Module):
         diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
         self.log_codes(codes)
         h = self.decoder(codevectors.permute(0,2,1))
-
-        if not hasattr(self, 'up') and return_decoder_latent:
-            return None, diversity, h
+        if return_decoder_latent:
+            return h, diversity
 
         reconstructed = self.up(h.float())
         reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
 
         mse = F.mse_loss(reconstructed, orig_mel)
-        if return_decoder_latent:
-            return mse, diversity, h
-        else:
-            return mse, diversity
+        return mse, diversity
 
     def log_codes(self, codes):
         if self.internal_step % 5 == 0:
diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py
index 8521437b..4d782f89 100644
--- a/codes/models/audio/music/transformer_diffusion8.py
+++ b/codes/models/audio/music/transformer_diffusion8.py
@@ -196,19 +196,16 @@ class TransformerDiffusion(nn.Module):
 
 
 class TransformerDiffusionWithQuantizer(nn.Module):
-    def __init__(self, freeze_quantizer_until=20000, quantizer_dims=[1024], no_reconstruction=True, **kwargs):
+    def __init__(self, freeze_quantizer_until=20000, **kwargs):
         super().__init__()
 
         self.internal_step = 0
         self.freeze_quantizer_until = freeze_quantizer_until
         self.diff = TransformerDiffusion(**kwargs)
-        self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims,
-                                         codevector_dim=quantizer_dims[0], checkpoint=False,
-                                         codebook_size=256, codebook_groups=2,
-                                         max_gumbel_temperature=4, min_gumbel_temperature=.5)
+        self.quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256,
+                                         codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5)
         self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature
-        if no_reconstruction:
-            del self.quantizer.up
+        del self.quantizer.up
 
     def update_for_step(self, step, *args):
         self.internal_step = step
@@ -219,30 +216,27 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         )
 
     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
-        mse, diversity_loss, proj = self.quantizer(truth_mel, return_decoder_latent=True)
-        proj = proj.permute(0,2,1)
-
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
+        with torch.set_grad_enabled(quant_grad_enabled):
+            proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True)
+            proj = proj.permute(0,2,1)
+
+        # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
         if not quant_grad_enabled:
-            proj = proj.detach()
-            # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
             unused = 0
             for p in self.quantizer.parameters():
                 unused = unused + p.mean() * 0
             proj = proj + unused
+            diversity_loss = diversity_loss * 0
 
-        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
-                         conditioning_free=conditioning_free)
-
+        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
         if disable_diversity:
             return diff
-        if mse is None:
-            return diff, diversity_loss
-        return diff, diversity_loss, mse
+        return diff, diversity_loss
 
     def get_debug_values(self, step, __):
         if self.quantizer.total_codes > 0:
-            return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes],
+            return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes],
                     'gumbel_temperature': self.quantizer.quantizer.temperature}
         else:
             return {}
@@ -320,26 +314,18 @@ def register_transformer_diffusion8_with_ar_prior(opt_net, opt):
 
 
 def test_quant_model():
-    clip = torch.randn(2, 100, 401)
+    clip = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithQuantizer(in_channels=100, out_channels=200, quantizer_dims=[1024,768,512,384],
-                                              model_channels=2048, block_channels=1024, prenet_channels=1024,
-                                              input_vec_dim=1024, num_layers=16, prenet_layers=6,
-                                              no_reconstruction=False)
-    #model.get_grad_norm_parameter_groups()
+    model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024,
+                                              input_vec_dim=1024, num_layers=16, prenet_layers=6)
+    model.get_grad_norm_parameter_groups()
 
-    #quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
-    #diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth')
-    #model.quantizer.load_state_dict(quant_weights, strict=False)
-    #model.diff.load_state_dict(diff_weights)
-
-    #torch.save(model.state_dict(), 'sample.pth')
     print_network(model)
     o = model(clip, ts, clip)
 
 
 def test_ar_model():
-    clip = torch.randn(2, 256, 401)
+    clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
     model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
@@ -357,8 +343,7 @@ def test_ar_model():
     model.diff.load_state_dict(pruned_diff_weights, strict=False)
     torch.save(model.state_dict(), 'sample.pth')
 
-    model(clip, ts, cond, conditioning_input=cond)
-
+    model(clip, ts, cond)
 
 
 if __name__ == '__main__':
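Note (not part of the patch): after these changes, MusicQuantizer2.forward with return_decoder_latent=True returns (latent, diversity) instead of (mse, diversity, latent), and TransformerDiffusionWithQuantizer.forward returns only (diffusion_output, diversity_loss), since the reconstruction path (quantizer.up) is now deleted unconditionally in the constructor. The sketch below is a minimal, hypothetical usage example of the updated interface; it assumes the repository's codes/ directory is on sys.path, and the tensor shapes simply follow test_quant_model().

```python
import torch

# Hypothetical usage sketch (not part of the diff); assumes codes/ is on sys.path.
from models.audio.music.transformer_diffusion8 import TransformerDiffusionWithQuantizer

model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024,
                                          input_vec_dim=1024, num_layers=16, prenet_layers=6)

clip = torch.randn(2, 256, 400)     # (batch, mel channels, frames), as in test_quant_model()
ts = torch.LongTensor([600, 600])   # one diffusion timestep per batch element

# With this patch, forward() returns two values: the diffusion model output and the
# codebook diversity loss. The reconstruction MSE term is no longer produced.
diff_out, diversity_loss = model(clip, ts, truth_mel=clip)
```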