diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 251767e7..568f0099 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -542,7 +542,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups) return idxs - def forward(self, hidden_states, mask_time_indices=None): + def forward(self, hidden_states, mask_time_indices=None, return_probs=False): batch_size, sequence_length, hidden_size = hidden_states.shape # project to codevector dim @@ -580,6 +580,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): .view(batch_size, sequence_length, -1) ) + if return_probs: + return codevectors, perplexity, codevector_probs.view(batch_size, sequence_length, self.num_groups, self.num_vars) return codevectors, perplexity diff --git a/codes/models/audio/music/transformer_diffusion3.py b/codes/models/audio/music/transformer_diffusion3.py deleted file mode 100644 index 43641303..00000000 --- a/codes/models/audio/music/transformer_diffusion3.py +++ /dev/null @@ -1,257 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class AttentionBlock(TimestepBlock): - def __init__(self, dim, heads, dropout): - super().__init__() - self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False) - self.ff = FeedForward(dim, mult=1, dropout=dropout, zero_init_output=True) - self.rms_scale_norm = RMSScaleShiftNorm(dim) - - def forward(self, x, timestep_emb, rotary_emb): - h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) - h = checkpoint(self.ff, h) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - model_channels=512, - num_layers=8, - in_channels=256, - in_latent_channels=512, - rotary_emb_dim=32, - token_count=8, - in_groups=None, - out_channels=512, # mean and variance - dropout=0, - use_fp16=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - heads = model_channels//64 - - self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(model_channels, model_channels), - nn.SiLU(), - linear(model_channels, model_channels), - ) - self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels // 2, 3, padding=1, stride=2), - nn.Conv1d(model_channels//2, model_channels,3,padding=1,stride=2)) - self.conditioning_encoder = Encoder( - dim=model_channels, - depth=4, - heads=heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. - # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally - # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive - # transformer network. - if in_groups is None: - self.embeddings = nn.Embedding(token_count, model_channels) - else: - self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) - self.latent_conditioner = nn.Sequential( - nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), - Encoder( - dim=model_channels, - depth=2, - heads=heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - ) - self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels)) - self.code_converter = Encoder( - dim=model_channels, - depth=3, - heads=heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels)) - self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) - - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = nn.Linear(model_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'contextual_embedder': list(self.conditioning_embedder.parameters()), - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, codes, conditioning_input, expected_seq_len, prenet_latent=None, return_code_pred=False): - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) - cond_emb = self.conditioning_encoder(cond_emb)[:, 0] - - code_emb = self.embeddings(codes) - if prenet_latent is not None: - latent_conditioning = self.latent_conditioner(prenet_latent) - code_emb = code_emb + latent_conditioning * self.latent_fade - - unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(codes.shape[0], 1, 1), - code_emb) - code_emb = self.code_converter(code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - if not return_code_pred: - return expanded_code_emb, cond_emb - else: - # Perform the mel_head computation on the pre-exanded code embeddings, then interpolate it separately. - mel_pred = self.mel_head(code_emb.permute(0,2,1)) - mel_pred = F.interpolate(mel_pred, size=expected_seq_len, mode='nearest') - # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. - # This is because we don't want that gradient being used to train parameters through the codes_embedder as - # it unbalances contributions to that network from the MSE loss. - mel_pred = mel_pred * unconditioned_batches.logical_not() - return expanded_code_emb, cond_emb, mel_pred - - - def forward(self, x, timesteps, codes=None, conditioning_input=None, prenet_latent=None, precomputed_code_embeddings=None, - precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False): - if precomputed_code_embeddings is not None: - assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified" - assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you." - - unused_params = [] - if not return_code_pred: - unused_params.extend(list(self.mel_head.parameters())) - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) - unused_params.extend(list(self.latent_conditioner.parameters())) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - cond_emb = precomputed_cond_embeddings - else: - code_emb, cond_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], prenet_latent, True) - if prenet_latent is None: - unused_params.extend(list(self.latent_conditioner.parameters()) + [self.latent_fade]) - unused_params.append(self.unconditioned_embedding) - - blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - x = self.layers(x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - if return_code_pred: - return out, mel_pred - return out - - -@register_model -def register_transformer_diffusion3(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -if __name__ == '__main__': - clip = torch.randn(2, 256, 400) - aligned_latent = torch.randn(2,100,512) - aligned_sequence = torch.randint(0,8,(2,100,8)) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusion(model_channels=2048, num_layers=8) - print_network(model) - #torchsummary.torchsummary.summary(model, clip, ts, aligned_sequence, cond, return_code_pred=True) - #o = model(clip, ts, aligned_sequence, cond, aligned_latent) - diff --git a/codes/models/audio/music/transformer_diffusion4.py b/codes/models/audio/music/transformer_diffusion4.py deleted file mode 100644 index a4e6e60c..00000000 --- a/codes/models/audio/music/transformer_diffusion4.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding -from trainer.networks import register_model -from utils.util import checkpoint, print_network - - -def is_latent(t): - return t.dtype == torch.float - -def is_sequence(t): - return t.dtype == torch.long - - -class MultiGroupEmbedding(nn.Module): - def __init__(self, tokens, groups, dim): - super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) - - def forward(self, x): - h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] - return torch.cat(h, dim=-1) - - -class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): - def forward(self, x, emb, rotary_emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb, rotary_emb) - else: - x = layer(x, rotary_emb) - return x - - -class DietAttentionBlock(TimestepBlock): - def __init__(self, in_dim, dim, heads, dropout): - super().__init__() - self.rms_scale_norm = RMSScaleShiftNorm(in_dim) - self.proj = nn.Linear(in_dim, dim) - self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) - self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) - - def forward(self, x, timestep_emb, rotary_emb): - h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) - h = self.proj(h) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) - h = checkpoint(self.ff, h) - return h + x - - -class TransformerDiffusion(nn.Module): - """ - A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? - """ - def __init__( - self, - prenet_channels=256, - model_channels=512, - block_channels=256, - num_layers=8, - in_channels=256, - rotary_emb_dim=32, - token_count=8, - in_groups=None, - out_channels=512, # mean and variance - dropout=0, - use_fp16=False, - # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.prenet_channels = prenet_channels - self.out_channels = out_channels - self.dropout = dropout - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - - self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) - - self.time_embed = nn.Sequential( - linear(prenet_channels, prenet_channels), - nn.SiLU(), - linear(prenet_channels, prenet_channels), - ) - prenet_heads = prenet_channels//64 - self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, prenet_channels // 2, 3, padding=1, stride=2), - nn.Conv1d(prenet_channels//2, prenet_channels,3,padding=1,stride=2)) - self.conditioning_encoder = Encoder( - dim=prenet_channels, - depth=4, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. - # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally - # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive - # transformer network. - if in_groups is None: - self.embeddings = nn.Embedding(token_count, prenet_channels) - else: - self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels) - self.code_converter = Encoder( - dim=prenet_channels, - depth=3, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.cond_intg = nn.Linear(prenet_channels*2, model_channels) - self.intg = nn.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - self.debug_codes = {} - - def get_grad_norm_parameter_groups(self): - groups = { - 'contextual_embedder': list(self.conditioning_embedder.parameters()), - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def timestep_independent(self, codes, conditioning_input, expected_seq_len): - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) - cond_emb = self.conditioning_encoder(cond_emb)[:, 0] - - code_emb = self.embeddings(codes) - - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), - device=code_emb.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(codes.shape[0], 1, 1), - code_emb) - code_emb = self.code_converter(code_emb) - - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) - return expanded_code_emb, cond_emb - - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, - precomputed_cond_embeddings=None, conditioning_free=False): - if precomputed_code_embeddings is not None: - assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - - unused_params = [] - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) - else: - if precomputed_code_embeddings is not None: - code_emb = precomputed_code_embeddings - cond_emb = precomputed_cond_embeddings - else: - code_emb, cond_emb = self.timestep_independent(codes, conditioning_input, x.shape[-1]) - unused_params.append(self.unconditioned_embedding) - - blk_emb = torch.cat([self.time_embed(timestep_embedding(timesteps, self.prenet_channels)), cond_emb], dim=-1) - blk_emb = self.cond_intg(blk_emb) - x = self.inp_block(x).permute(0,2,1) - - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) - x = self.intg(torch.cat([x, code_emb], dim=-1)) - for layer in self.layers: - x = checkpoint(layer, x, blk_emb, rotary_pos_emb) - - x = x.float().permute(0,2,1) - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -@register_model -def register_transformer_diffusion4(opt_net, opt): - return TransformerDiffusion(**opt_net['kwargs']) - - -if __name__ == '__main__': - clip = torch.randn(2, 256, 400) - aligned_sequence = torch.randint(0,8,(2,100,8)) - cond = torch.randn(2, 256, 400) - ts = torch.LongTensor([600, 600]) - model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536, num_layers=16, in_groups=8) - torch.save(model, 'sample.pth') - print_network(model) - o = model(clip, ts, aligned_sequence, cond) - diff --git a/codes/models/audio/music/transformer_diffusion5.py b/codes/models/audio/music/transformer_diffusion5.py index 106913e6..3198de1a 100644 --- a/codes/models/audio/music/transformer_diffusion5.py +++ b/codes/models/audio/music/transformer_diffusion5.py @@ -164,8 +164,10 @@ class TransformerDiffusion(nn.Module): unused_params = [] if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) + cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) + cond_emb = self.conditioning_encoder(cond_emb)[:, 0] + unused_params.extend(list(self.code_converter.parameters())) else: if precomputed_code_embeddings is not None: code_emb = precomputed_code_embeddings @@ -195,11 +197,70 @@ class TransformerDiffusion(nn.Module): return out +class TransformerDiffusionWithQuantizer(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + self.diff = TransformerDiffusion(**kwargs) + from models.audio.mel2vec import ContrastiveTrainingWrapper + self.m2v = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0.1, + mask_time_prob=0, mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4, + disable_custom_linear_init=True, do_reconstruction_loss=True) + self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature + + self.codes = torch.zeros((3000000,), dtype=torch.long) + self.internal_step = 0 + self.code_ind = 0 + self.total_codes = 0 + + del self.m2v.m2v.encoder + + def update_for_step(self, step, *args): + self.internal_step = step + self.m2v.quantizer.temperature = max( + self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**step, + self.m2v.min_gumbel_temperature, + ) + + def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False): + proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1) + _, proj = self.m2v.m2v.projector(proj) + vectors, _, probs = self.m2v.quantizer(proj, return_probs=True) + self.log_codes(probs) + return self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free) + + def log_codes(self, codes): + if self.internal_step % 5 == 0: + codes = torch.argmax(codes, dim=-1) + codes = codes[:,:,0] + codes[:,:,1] * 16 + codes[:,:,2] * 16 ** 2 + codes[:,:,3] * 16 ** 3 + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i:i+l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.total_codes += 1 + + def get_debug_values(self, step, __): + if self.total_codes > 0: + return {'histogram_codes': self.codes[:self.total_codes]} + else: + return {} + + @register_model def register_transformer_diffusion5(opt_net, opt): return TransformerDiffusion(**opt_net['kwargs']) +@register_model +def register_transformer_diffusion5_with_quantizer(opt_net, opt): + return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) + + +""" +# For TFD5 if __name__ == '__main__': clip = torch.randn(2, 256, 400) aligned_sequence = torch.randn(2,100,512) @@ -209,4 +270,20 @@ if __name__ == '__main__': torch.save(model, 'sample.pth') print_network(model) o = model(clip, ts, aligned_sequence, cond) +""" + +if __name__ == '__main__': + clip = torch.randn(2, 256, 400) + cond = torch.randn(2, 256, 400) + ts = torch.LongTensor([600, 600]) + model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, num_layers=16) + + quant_weights = torch.load('../experiments/m2v_music.pth') + diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth') + model.m2v.load_state_dict(quant_weights, strict=False) + model.diff.load_state_dict(diff_weights) + + torch.save(model.state_dict(), 'sample.pth') + print_network(model) + o = model(clip, ts, clip, cond) diff --git a/codes/train.py b/codes/train.py index 7f9e11a2..5447cc1c 100644 --- a/codes/train.py +++ b/codes/train.py @@ -332,7 +332,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_tfd.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 727fe284..5518cacf 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -60,6 +60,8 @@ class MusicDiffusionFid(evaluator.Evaluator): elif 'from_codes' == mode: self.diffusion_fn = self.perform_diffusion_from_codes self.local_modules['codegen'] = get_music_codegen() + elif 'from_codes_quant' == mode: + self.diffusion_fn = self.perform_diffusion_from_codes_quant self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'normalize': True, 'in': 'in', 'out': 'out'}, {}) @@ -92,12 +94,35 @@ class MusicDiffusionFid(evaluator.Evaluator): codes = codegen.get_codes(mel, project=True) mel_norm = normalize_mel(mel) gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, - model_kwargs={'codes': codes, 'conditioning_input': mel_norm[:,:,:140]}) + model_kwargs={'codes': codes, 'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])}) gen_mel_denorm = denormalize_mel(gen_mel) output_shape = (1,16,audio.shape[-1]//16) self.spec_decoder = self.spec_decoder.to(audio.device) - gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, model_kwargs={'aligned_conditioning': gen_mel_denorm}) + gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, + model_kwargs={'aligned_conditioning': gen_mel_denorm}) + gen_wav = pixel_shuffle_1d(gen_wav, 16) + + return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate + + def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050): + if sample_rate != sample_rate: + real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) + else: + real_resampled = audio + audio = audio.unsqueeze(0) + + mel = self.spec_fn({'in': audio})['out'] + mel_norm = normalize_mel(mel) + gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, + model_kwargs={'truth_mel': mel, + 'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])}) + + gen_mel_denorm = denormalize_mel(gen_mel) + output_shape = (1,16,audio.shape[-1]//16) + self.spec_decoder = self.spec_decoder.to(audio.device) + gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, + model_kwargs={'aligned_conditioning': gen_mel_denorm}) gen_wav = pixel_shuffle_1d(gen_wav, 16) return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate @@ -164,18 +189,19 @@ class MusicDiffusionFid(evaluator.Evaluator): # Put modules used for evaluation back into CPU memory. for k, mod in self.local_modules.items(): self.local_modules[k] = mod.cpu() + self.spec_decoder = self.spec_decoder.cpu() return {"frechet_distance": frechet_distance} if __name__ == '__main__': - diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd.yml', 'generator', + diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd\\models\\3000_generator_ema.pth' + load_path='X:\\dlas\\experiments\\music_tfd5_with_quantizer_basis.pth' ).cuda() - opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500, - 'conditioning_free': True, 'conditioning_free_k': 1, - 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes'} - env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 26, 'device': 'cuda', 'opt': {}} + opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100, + 'conditioning_free': False, 'conditioning_free_k': 2, + 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'} + env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 558, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval())