remove mel_pred
This commit is contained in:
parent
e9bb692490
commit
d5fb79564a
|
@ -186,27 +186,19 @@ class MusicGenerator(nn.Module):
|
||||||
return truth * mask
|
return truth * mask
|
||||||
|
|
||||||
|
|
||||||
def timestep_independent(self, aligned_conditioning, expected_seq_len, return_code_pred):
|
def timestep_independent(self, truth, expected_seq_len, return_code_pred):
|
||||||
code_emb = self.conditioner(aligned_conditioning)
|
code_emb = self.conditioner(truth)
|
||||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
|
||||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||||
if self.training and self.unconditioned_percentage > 0:
|
if self.training and self.unconditioned_percentage > 0:
|
||||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||||
device=code_emb.device) < self.unconditioned_percentage
|
device=code_emb.device) < self.unconditioned_percentage
|
||||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1),
|
||||||
code_emb)
|
code_emb)
|
||||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
||||||
|
|
||||||
if not return_code_pred:
|
|
||||||
return expanded_code_emb
|
return expanded_code_emb
|
||||||
else:
|
|
||||||
mel_pred = self.mel_head(expanded_code_emb)
|
|
||||||
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
|
|
||||||
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
|
||||||
return expanded_code_emb, mel_pred
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
|
|
||||||
|
@ -218,7 +210,6 @@ class MusicGenerator(nn.Module):
|
||||||
:return: an [N x C x ...] Tensor of outputs.
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
"""
|
"""
|
||||||
assert precomputed_aligned_embeddings is not None or truth is not None
|
assert precomputed_aligned_embeddings is not None or truth is not None
|
||||||
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
|
||||||
|
|
||||||
unused_params = []
|
unused_params = []
|
||||||
if conditioning_free:
|
if conditioning_free:
|
||||||
|
@ -229,7 +220,7 @@ class MusicGenerator(nn.Module):
|
||||||
code_emb = precomputed_aligned_embeddings
|
code_emb = precomputed_aligned_embeddings
|
||||||
else:
|
else:
|
||||||
truth = self.do_masking(truth)
|
truth = self.do_masking(truth)
|
||||||
code_emb, mel_pred = self.timestep_independent(truth, x.shape[-1], True)
|
code_emb = self.timestep_independent(truth, x.shape[-1], True)
|
||||||
unused_params.append(self.unconditioned_embedding)
|
unused_params.append(self.unconditioned_embedding)
|
||||||
|
|
||||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||||
|
@ -255,8 +246,6 @@ class MusicGenerator(nn.Module):
|
||||||
extraneous_addition = extraneous_addition + p.mean()
|
extraneous_addition = extraneous_addition + p.mean()
|
||||||
out = out + extraneous_addition * 0
|
out = out + extraneous_addition * 0
|
||||||
|
|
||||||
if return_code_pred:
|
|
||||||
return out, mel_pred
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user