fixed aligned_latent

This commit is contained in:
James Betker 2022-05-06 00:20:21 -06:00
parent 1609101a42
commit e9bb692490

View File

@ -187,8 +187,6 @@ class MusicGenerator(nn.Module):
def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred): def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred):
# Shuffle aligned_latent to BxCxS format
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
code_emb = self.conditioner(aligned_conditioning) code_emb = self.conditioner(aligned_conditioning)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
@ -269,7 +267,7 @@ def register_music_gap_gen(opt_net, opt):
if __name__ == '__main__': if __name__ == '__main__':
clip = torch.randn(2, 100, 400) clip = torch.randn(2, 100, 400)
aligned_latent = torch.randn(2,388,100) aligned_latent = torch.randn(2,100,388)
ts = torch.LongTensor([600, 600]) ts = torch.LongTensor([600, 600])
model = MusicGenerator(512, layer_drop=.3, unconditioned_percentage=.5) model = MusicGenerator(512, layer_drop=.3, unconditioned_percentage=.5)
o = model(clip, ts, aligned_latent) o = model(clip, ts, aligned_latent)