diff --git a/codes/models/audio/music/flat_diffusion.py b/codes/models/audio/music/flat_diffusion.py
index a9ae652f..79c222bc 100644
--- a/codes/models/audio/music/flat_diffusion.py
+++ b/codes/models/audio/music/flat_diffusion.py
@@ -5,6 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
+from models.arch_util import ResBlock
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock
 from trainer.networks import register_model
@@ -18,7 +19,7 @@ def is_sequence(t):
     return t.dtype == torch.long
 
 
-class ResBlock(TimestepBlock):
+class TimestepResBlock(TimestepBlock):
     def __init__(
         self,
         channels,
@@ -98,7 +99,7 @@ class ResBlock(TimestepBlock):
 class DiffusionLayer(TimestepBlock):
     def __init__(self, model_channels, dropout, num_heads):
         super().__init__()
-        self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
+        self.resblk = TimestepResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
         self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
 
     def forward(self, x, time_emb):
@@ -122,6 +123,7 @@ class FlatDiffusion(nn.Module):
             # Parameters for regularization.
             layer_drop=.1,
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
+            train_mel_head=False,
     ):
         super().__init__()
 
@@ -154,9 +156,11 @@ class FlatDiffusion(nn.Module):
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
         )
         self.code_converter = nn.Sequential(
+            ResBlock(dims=1, channels=model_channels, dropout=dropout),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            ResBlock(dims=1, channels=model_channels, dropout=dropout),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            ResBlock(dims=1, channels=model_channels, dropout=dropout),
         )
         self.code_norm = normalization(model_channels)
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
@@ -176,7 +180,7 @@ class FlatDiffusion(nn.Module):
         self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
 
         self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
-                                    [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
+                                    [TimestepResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
 
         self.out = nn.Sequential(
             normalization(model_channels),
@@ -184,6 +188,13 @@ class FlatDiffusion(nn.Module):
             zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
         )
 
+        if train_mel_head:
+            for m in [self.conditioning_timestep_integrator, self.integrating_conv, self.layers,
+                      self.out]:
+                for p in m.parameters():
+                    p.requires_grad = False
+                    p.DO_NOT_TRAIN = True
+
     def get_grad_norm_parameter_groups(self):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
@@ -213,7 +224,6 @@ class FlatDiffusion(nn.Module):
         else:
            code_emb = [embedding(aligned_conditioning[:, :, i]) for i, embedding in enumerate(self.embeddings)]
            code_emb = torch.cat(code_emb, dim=-1).permute(0,2,1)
-        code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
 
         unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
@@ -224,6 +234,7 @@
                                    code_emb)
         expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
         expanded_code_emb = self.code_converter(expanded_code_emb)
+        expanded_code_emb = self.code_norm(expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
 
         if not return_code_pred:
             return expanded_code_emb
@@ -313,7 +324,7 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8,(2,100,8))
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5)
+    model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, train_mel_head=True)
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
diff --git a/codes/scripts/audio/gen/use_mel2vec_codes.py b/codes/scripts/audio/gen/use_mel2vec_codes.py
index 848561a2..207ab1bc 100644
--- a/codes/scripts/audio/gen/use_mel2vec_codes.py
+++ b/codes/scripts/audio/gen/use_mel2vec_codes.py
@@ -25,7 +25,7 @@ def recover_codegroups(codes, groups):
 if __name__ == '__main__':
     model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
                                        mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
-    model.load_state_dict(torch.load("X:\\dlas\\experiments\\train_music_mel2vec\\models\\29000_generator_ema.pth"))
+    model.load_state_dict(torch.load("../experiments/m2v_music.pth"))
     model.eval()
 
     wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)