James Betker 2022-04-01 16:11:34 -06:00
parent 035bcd9f6c
commit 8623c51902


@@ -200,7 +200,7 @@ class DiffusionTtsFlat(nn.Module):
         }
         return groups

-    def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
+    def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
         # Shuffle aligned_latent to BxCxS format
         if is_latent(aligned_conditioning):
             aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
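The signature change makes the output length an explicit argument: previously `timestep_independent` assumed the mel sequence was always exactly 4x the conditioning token count. A minimal sketch (all shapes here are hypothetical, chosen only for illustration) of how that assumption breaks when the actual target length is not an exact multiple:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes illustrating the old hardcoded 4x upsample.
code_emb = torch.randn(1, 512, 100)   # (batch, channels, tokens)
mel_len = 403                         # actual diffusion target length; not 100 * 4

old = F.interpolate(code_emb, size=code_emb.shape[-1] * 4, mode='nearest')  # 400 frames
new = F.interpolate(code_emb, size=mel_len, mode='nearest')                 # 403 frames
print(old.shape[-1], new.shape[-1])   # 400 403 -- the old path is 3 frames short
```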
@@ -228,7 +228,7 @@ class DiffusionTtsFlat(nn.Module):
                                                device=code_emb.device) < self.unconditioned_percentage
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                    code_emb)
-        expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
+        expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')

         if not return_code_pred:
             return expanded_code_emb
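Since `mode='nearest'` simply repeats each token embedding across the frames it covers, `expected_seq_len` can be any positive length rather than an exact multiple of the token count. A quick self-contained check (scalar token ids stand in for embedding vectors):

```python
import torch
import torch.nn.functional as F

# Five token "embeddings" stretched over 12 frames via nearest-neighbor upsampling.
tokens = torch.arange(5, dtype=torch.float32).view(1, 1, 5)  # (B, C, S)
frames = F.interpolate(tokens, size=12, mode='nearest')
print(frames.squeeze().tolist())
# -> [0.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
```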
@@ -263,7 +263,7 @@ class DiffusionTtsFlat(nn.Module):
         if precomputed_aligned_embeddings is not None:
             code_emb = precomputed_aligned_embeddings
         else:
-            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
+            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
             unused_params.append(self.unconditioned_embedding)

         if is_latent(aligned_conditioning):
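On the caller side, `forward` now derives the target length from the diffusion input itself (`x.shape[-1]`), so the expanded conditioning lines up frame-for-frame with the noised mel. A self-contained sketch of that invariant, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 100, 407)         # noised mel being denoised: (B, n_mels, frames)
code_emb = torch.randn(2, 512, 101)  # per-token conditioning: (B, C, tokens)

expanded = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
assert expanded.shape[-1] == x.shape[-1]  # holds for any frame count, not just 4 * tokens
```

Note that callers taking the `precomputed_aligned_embeddings` path above bypass this and must now pass the intended output length to `timestep_independent` themselves when precomputing.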