remove mel_pred

This commit is contained in:
James Betker 2022-05-06 00:24:05 -06:00
parent e9bb692490
commit d5fb79564a

View File

@ -186,27 +186,19 @@ class MusicGenerator(nn.Module):
return truth * mask return truth * mask
def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred): def timestep_independent(self, truth, expected_seq_len, return_code_pred):
code_emb = self.conditioner(aligned_conditioning) code_emb = self.conditioner(truth)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0: if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
device=code_emb.device) < self.unconditioned_percentage device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
code_emb) code_emb)
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest') expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
if not return_code_pred:
return expanded_code_emb return expanded_code_emb
else:
mel_pred = self.mel_head(expanded_code_emb)
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
mel_pred = mel_pred * unconditioned_batches.logical_not()
return expanded_code_emb, mel_pred
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False): def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False):
""" """
Apply the model to an input batch. Apply the model to an input batch.
@ -218,7 +210,6 @@ class MusicGenerator(nn.Module):
:return: an [N x C x ...] Tensor of outputs. :return: an [N x C x ...] Tensor of outputs.
""" """
assert precomputed_aligned_embeddings is not None or truth is not None assert precomputed_aligned_embeddings is not None or truth is not None
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
unused_params = [] unused_params = []
if conditioning_free: if conditioning_free:
@ -229,7 +220,7 @@ class MusicGenerator(nn.Module):
code_emb = precomputed_aligned_embeddings code_emb = precomputed_aligned_embeddings
else: else:
truth = self.do_masking(truth) truth = self.do_masking(truth)
code_emb, mel_pred = self.timestep_independent(truth, x.shape[-1], True) code_emb = self.timestep_independent(truth, x.shape[-1], True)
unused_params.append(self.unconditioned_embedding) unused_params.append(self.unconditioned_embedding)
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
@ -255,8 +246,6 @@ class MusicGenerator(nn.Module):
extraneous_addition = extraneous_addition + p.mean() extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0 out = out + extraneous_addition * 0
if return_code_pred:
return out, mel_pred
return out return out