fixed aligned_latent
This commit is contained in:
parent
1609101a42
commit
e9bb692490
|
@ -187,8 +187,6 @@ class MusicGenerator(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred):
|
def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred):
|
||||||
# Shuffle aligned_latent to BxCxS format
|
|
||||||
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
|
||||||
code_emb = self.conditioner(aligned_conditioning)
|
code_emb = self.conditioner(aligned_conditioning)
|
||||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||||
|
@ -269,7 +267,7 @@ def register_music_gap_gen(opt_net, opt):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
clip = torch.randn(2, 100, 400)
|
clip = torch.randn(2, 100, 400)
|
||||||
aligned_latent = torch.randn(2,388,100)
|
aligned_latent = torch.randn(2,100,388)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = MusicGenerator(512, layer_drop=.3, unconditioned_percentage=.5)
|
model = MusicGenerator(512, layer_drop=.3, unconditioned_percentage=.5)
|
||||||
o = model(clip, ts, aligned_latent)
|
o = model(clip, ts, aligned_latent)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user