mel_head should be optional

This commit is contained in:
James Betker 2022-05-22 12:25:45 -06:00
parent 37640e9759
commit 2dd0b9e6e9

View File

@ -271,6 +271,8 @@ class FlatDiffusion(nn.Module):
assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
unused_params = []
if not return_code_pred:
unused_params.extend(list(self.mel_head.parameters()))
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))