diff --git a/codes/models/audio/music/music_gen_fill_gaps.py b/codes/models/audio/music/music_gen_fill_gaps.py
index 2605f5f2..45dd0624 100644
--- a/codes/models/audio/music/music_gen_fill_gaps.py
+++ b/codes/models/audio/music/music_gen_fill_gaps.py
@@ -186,27 +186,19 @@ class MusicGenerator(nn.Module):
         return truth * mask
 
-    def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred):
-        code_emb = self.conditioner(aligned_conditioning)
-        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
+    def timestep_independent(self, truth, expected_seq_len, return_code_pred):
+        code_emb = self.conditioner(truth)
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
             unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage
-            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
+            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
                                    code_emb)
 
         expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
-
-        if not return_code_pred:
-            return expanded_code_emb
-        else:
-            mel_pred = self.mel_head(expanded_code_emb)
-            # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
-            mel_pred = mel_pred * unconditioned_batches.logical_not()
-            return expanded_code_emb, mel_pred
+        return expanded_code_emb
 
-    def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
+    def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False):
         """
         Apply the model to an input batch.
@@ -218,7 +210,6 @@ class MusicGenerator(nn.Module):
         :return: an [N x C x ...] Tensor of outputs.
         """
         assert precomputed_aligned_embeddings is not None or truth is not None
-        assert not (return_code_pred and precomputed_aligned_embeddings is not None)  # These two are mutually exclusive.
 
         unused_params = []
         if conditioning_free:
@@ -229,7 +220,7 @@ class MusicGenerator(nn.Module):
             code_emb = precomputed_aligned_embeddings
         else:
             truth = self.do_masking(truth)
-            code_emb, mel_pred = self.timestep_independent(truth, x.shape[-1], True)
+            code_emb = self.timestep_independent(truth, x.shape[-1], True)
             unused_params.append(self.unconditioned_embedding)
 
         time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
@@ -255,8 +246,6 @@ class MusicGenerator(nn.Module):
                 extraneous_addition = extraneous_addition + p.mean()
             out = out + extraneous_addition * 0
 
-        if return_code_pred:
-            return out, mel_pred
         return out
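For reference, the conditioning-dropout trick that `timestep_independent` keeps after this change can be sketched in isolation. The `ConditioningDropout` module name, channel count, and sequence lengths below are illustrative assumptions for the sketch, not names from this repo; only the `rand`/`where`/`interpolate` pattern mirrors the diff:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditioningDropout(nn.Module):
    """Per-batch-element conditioning dropout, in the style of classifier-free guidance.
    (Hypothetical module for illustration; not part of music_gen_fill_gaps.py.)"""

    def __init__(self, channels, unconditioned_percentage=0.1):
        super().__init__()
        self.unconditioned_percentage = unconditioned_percentage
        # Learned stand-in embedding for "no conditioning"; broadcast across the sequence.
        self.unconditioned_embedding = nn.Parameter(torch.randn(1, channels, 1))

    def forward(self, code_emb, expected_seq_len):
        # code_emb: [N, C, S] conditioning derived from the (masked) ground truth.
        if self.training and self.unconditioned_percentage > 0:
            # One Bernoulli draw per batch element; True means "drop this element's conditioning".
            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                               device=code_emb.device) < self.unconditioned_percentage
            code_emb = torch.where(unconditioned_batches,
                                   self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1),
                                   code_emb)
        # Stretch the conditioning to the diffusion model's sequence length, as the diff does.
        return F.interpolate(code_emb, size=expected_seq_len, mode='nearest')

# Usage: roughly 10% of batch elements train against the unconditioned embedding.
dropout = ConditioningDropout(channels=512, unconditioned_percentage=0.1).train()
expanded = dropout(torch.randn(4, 512, 25), expected_seq_len=100)  # -> [4, 512, 100]
```

One observation on the diff itself: the simplified `timestep_independent` still declares `return_code_pred`, and the call site in `forward` still passes `True` for it, even though the value is no longer read. The parameter is dead and could be dropped from both the signature and the call in a follow-up.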