From 8587a18717844cb8cb2d44db5f378606882bd21e Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 26 May 2022 20:19:09 -0600 Subject: [PATCH] fd fix --- codes/models/audio/music/flat_diffusion.py | 35 +++++++++++++------ .../audio/music/transformer_diffusion4.py | 4 +-- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/codes/models/audio/music/flat_diffusion.py b/codes/models/audio/music/flat_diffusion.py index 79c222bc..9def768a 100644 --- a/codes/models/audio/music/flat_diffusion.py +++ b/codes/models/audio/music/flat_diffusion.py @@ -19,6 +19,16 @@ def is_sequence(t): return t.dtype == torch.long +class MultiGroupEmbedding(nn.Module): + def __init__(self, tokens, groups, dim): + super().__init__() + self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) + + def forward(self, x): + h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] + return torch.cat(h, dim=-1) + + class TimestepResBlock(TimestepBlock): def __init__( self, @@ -114,8 +124,8 @@ class FlatDiffusion(nn.Module): num_layers=8, in_channels=256, in_latent_channels=512, - in_vectors=8, - in_groups=8, + token_count=8, + in_groups=None, out_channels=512, # mean and variance dropout=0, use_fp16=False, @@ -147,7 +157,10 @@ class FlatDiffusion(nn.Module): # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. - self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)]) + if in_groups is None: + self.embeddings = nn.Embedding(token_count, model_channels) + else: + self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) self.latent_conditioner = nn.Sequential( nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), @@ -222,8 +235,8 @@ class FlatDiffusion(nn.Module): if is_latent(aligned_conditioning): code_emb = self.latent_conditioner(aligned_conditioning) else: - code_emb = [embedding(aligned_conditioning[:, :, i]) for i, embedding in enumerate(self.embeddings)] - code_emb = torch.cat(code_emb, dim=-1).permute(0,2,1) + code_emb = self.embeddings(aligned_conditioning) + code_emb = code_emb.permute(0,2,1) unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. @@ -245,19 +258,19 @@ class FlatDiffusion(nn.Module): return expanded_code_emb, mel_pred - def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False): + def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. - :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced. + :param codes: an aligned latent or sequence of tokens providing useful data about the sample to be produced. :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded. :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent() :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. :return: an [N x C x ...] Tensor of outputs. """ - assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None) + assert precomputed_aligned_embeddings is not None or (codes is not None and conditioning_input is not None) assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive. unused_params = [] @@ -269,8 +282,8 @@ class FlatDiffusion(nn.Module): if precomputed_aligned_embeddings is not None: code_emb = precomputed_aligned_embeddings else: - code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True) - if is_latent(aligned_conditioning): + code_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], True) + if is_latent(codes): unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) else: unused_params.extend(list(self.latent_conditioner.parameters())) @@ -324,7 +337,7 @@ if __name__ == '__main__': aligned_sequence = torch.randint(0,8,(2,100,8)) cond = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) - model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, train_mel_head=True) + model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True) # Test with latent aligned conditioning #o = model(clip, ts, aligned_latent, cond) # Test with sequence aligned conditioning diff --git a/codes/models/audio/music/transformer_diffusion4.py b/codes/models/audio/music/transformer_diffusion4.py index b8e294f1..9ce30735 100644 --- a/codes/models/audio/music/transformer_diffusion4.py +++ b/codes/models/audio/music/transformer_diffusion4.py @@ -138,7 +138,7 @@ class TransformerDiffusion(nn.Module): self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels)) self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) - self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) + self.rotary_embeddings = None if rotary_emb_dim is None else RotaryEmbedding(rotary_emb_dim) self.layers = SpecialSequential(*[AttentionBlock(model_channels, model_channels // 64, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( @@ -217,7 +217,7 @@ class TransformerDiffusion(nn.Module): blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb x = self.inp_block(x).permute(0,2,1) - rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) + rotary_pos_emb = None if self.rotary_embeddings is None else self.rotary_embeddings(x.shape[1], x.device) x = self.layers(x, code_emb, blk_emb, rotary_pos_emb) x = x.float().permute(0,2,1)