remove mel_pred
This commit is contained in:
parent
e9bb692490
commit
d5fb79564a
|
@ -186,27 +186,19 @@ class MusicGenerator(nn.Module):
|
|||
return truth * mask
|
||||
|
||||
|
||||
def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred):
|
||||
code_emb = self.conditioner(aligned_conditioning)
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
def timestep_independent(self, truth, expected_seq_len, return_code_pred):
|
||||
code_emb = self.conditioner(truth)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
|
||||
code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
||||
|
||||
if not return_code_pred:
|
||||
return expanded_code_emb
|
||||
else:
|
||||
mel_pred = self.mel_head(expanded_code_emb)
|
||||
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
|
||||
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
||||
return expanded_code_emb, mel_pred
|
||||
return expanded_code_emb
|
||||
|
||||
|
||||
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
|
@ -218,7 +210,6 @@ class MusicGenerator(nn.Module):
|
|||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert precomputed_aligned_embeddings is not None or truth is not None
|
||||
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
||||
|
||||
unused_params = []
|
||||
if conditioning_free:
|
||||
|
@ -229,7 +220,7 @@ class MusicGenerator(nn.Module):
|
|||
code_emb = precomputed_aligned_embeddings
|
||||
else:
|
||||
truth = self.do_masking(truth)
|
||||
code_emb, mel_pred = self.timestep_independent(truth, x.shape[-1], True)
|
||||
code_emb = self.timestep_independent(truth, x.shape[-1], True)
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
@ -255,8 +246,6 @@ class MusicGenerator(nn.Module):
|
|||
extraneous_addition = extraneous_addition + p.mean()
|
||||
out = out + extraneous_addition * 0
|
||||
|
||||
if return_code_pred:
|
||||
return out, mel_pred
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user