From e9bb692490eca27875b53f6be2e711667036da7d Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 6 May 2022 00:20:21 -0600 Subject: [PATCH] fixed aligned_latent --- codes/models/audio/music/music_gen_fill_gaps.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/codes/models/audio/music/music_gen_fill_gaps.py b/codes/models/audio/music/music_gen_fill_gaps.py index 91041c4c..2605f5f2 100644 --- a/codes/models/audio/music/music_gen_fill_gaps.py +++ b/codes/models/audio/music/music_gen_fill_gaps.py @@ -187,8 +187,6 @@ class MusicGenerator(nn.Module): def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred): - # Shuffle aligned_latent to BxCxS format - aligned_conditioning = aligned_conditioning.permute(0, 2, 1) code_emb = self.conditioner(aligned_conditioning) unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. @@ -269,7 +267,7 @@ def register_music_gap_gen(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 100, 400) - aligned_latent = torch.randn(2,388,100) + aligned_latent = torch.randn(2,100,388) ts = torch.LongTensor([600, 600]) model = MusicGenerator(512, layer_drop=.3, unconditioned_percentage=.5) o = model(clip, ts, aligned_latent)