From 968660c248cdbb95ebf6ab59c086f035335cb0d1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 20 May 2022 11:25:00 -0600 Subject: [PATCH] another update --- codes/models/audio/music/flat_diffusion.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/codes/models/audio/music/flat_diffusion.py b/codes/models/audio/music/flat_diffusion.py index 8e2c5ddd..a9ae652f 100644 --- a/codes/models/audio/music/flat_diffusion.py +++ b/codes/models/audio/music/flat_diffusion.py @@ -146,7 +146,6 @@ class FlatDiffusion(nn.Module): # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)]) - self.code_norm = normalization(model_channels) self.latent_conditioner = nn.Sequential( nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), @@ -154,6 +153,12 @@ class FlatDiffusion(nn.Module): AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), ) + self.code_converter = nn.Sequential( + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) + self.code_norm = normalization(model_channels) self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2), nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2), AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), @@ -183,7 +188,7 @@ class FlatDiffusion(nn.Module): groups = { 'minicoder': list(self.contextual_embedder.parameters()), 'layers': list(self.layers.parameters()), - 'code_converters': list(self.embeddings.parameters()) + list(self.latent_conditioner.parameters()), + 'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()), 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()), 'time_embed': list(self.time_embed.parameters()), } @@ -218,6 +223,7 @@ class FlatDiffusion(nn.Module): code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), code_emb) expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest') + expanded_code_emb = self.code_converter(expanded_code_emb) if not return_code_pred: return expanded_code_emb @@ -311,5 +317,5 @@ if __name__ == '__main__': # Test with latent aligned conditioning #o = model(clip, ts, aligned_latent, cond) # Test with sequence aligned conditioning - o = model(clip, ts, aligned_sequence, cond) + o = model(clip, ts, aligned_sequence, cond, return_code_pred=True)