Fix aligned_latent tensor shape: pass the latent as (batch, channels, seq) instead of permuting inside timestep_independent
This commit is contained in:
parent
1609101a42
commit
e9bb692490
|
@ -187,8 +187,6 @@ class MusicGenerator(nn.Module):
|
|||
|
||||
|
||||
def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred):
|
||||
# Shuffle aligned_latent to BxCxS format
|
||||
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
||||
code_emb = self.conditioner(aligned_conditioning)
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
|
@ -269,7 +267,7 @@ def register_music_gap_gen(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
    # Smoke test: build a MusicGenerator and run one forward pass on random inputs.
    # Removed the dead first assignment of aligned_latent with shape (2, 388, 100):
    # it was immediately overwritten and its layout predates the shape fix in this
    # commit. The model now takes the latent directly as (batch, channels, seq)
    # = (2, 100, 388) — assumed from the replacement line; TODO confirm against
    # MusicGenerator.forward.
    clip = torch.randn(2, 100, 400)
    aligned_latent = torch.randn(2, 100, 388)
    ts = torch.LongTensor([600, 600])
    model = MusicGenerator(512, layer_drop=.3, unconditioned_percentage=.5)
    o = model(clip, ts, aligned_latent)
|
||||
|
|
Loading…
Reference in New Issue
Block a user