commit 8623c51902 (parent 035bcd9f6c)
Author: James Betker
Date:   2022-04-01 16:11:34 -06:00

@@ -200,7 +200,7 @@ class DiffusionTtsFlat(nn.Module):
         }
         return groups
 
-    def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
+    def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
         # Shuffle aligned_latent to BxCxS format
         if is_latent(aligned_conditioning):
             aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
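
Net effect of this commit: timestep_independent now takes the target sequence length as an explicit expected_seq_len argument, instead of assuming the expanded conditioning should be exactly 4x the length of aligned_conditioning (a sizing sketch follows the next hunk).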
@@ -228,7 +228,7 @@ class DiffusionTtsFlat(nn.Module):
                                                device=code_emb.device) < self.unconditioned_percentage
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                    code_emb)
-        expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
+        expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
 
         if not return_code_pred:
             return expanded_code_emb
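
A minimal sketch of the sizing change in this hunk, using F.interpolate on illustrative tensors (the batch, channel, and sequence sizes below are made up, not the model's real dimensions):

    import torch
    import torch.nn.functional as F

    code_emb = torch.randn(2, 512, 100)   # (batch, channels, seq); illustrative sizes
    expected_seq_len = 403                # e.g. the diffusion input's x.shape[-1]

    # Old behavior: target length hard-coded to 4x the conditioning length.
    old = F.interpolate(code_emb, size=code_emb.shape[-1] * 4, mode='nearest')
    # New behavior: target length supplied explicitly by the caller.
    new = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')

    print(old.shape)  # torch.Size([2, 512, 400])
    print(new.shape)  # torch.Size([2, 512, 403])

The nearest-neighbor expansion now lands on exactly the length the caller asks for, rather than on a multiple that only matches when the 4:1 ratio happens to hold.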
@@ -263,7 +263,7 @@ class DiffusionTtsFlat(nn.Module):
         if precomputed_aligned_embeddings is not None:
             code_emb = precomputed_aligned_embeddings
         else:
-            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
+            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
 
         unused_params.append(self.unconditioned_embedding)
         if is_latent(aligned_conditioning):
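
At the call site, the forward pass now passes x.shape[-1] (the time dimension of the noisy input) so the expanded conditioning matches x along the sequence axis. A hedged sketch of why that alignment matters, using a simplified stand-in for the real method (function name and shapes here are hypothetical):

    import torch
    import torch.nn.functional as F

    def timestep_independent_sketch(code_emb, expected_seq_len):
        # Stand-in for the real method: expand conditioning to the caller's length.
        return F.interpolate(code_emb, size=expected_seq_len, mode='nearest')

    x = torch.randn(2, 100, 407)         # noisy input (batch, channels, time); illustrative
    code_emb = torch.randn(2, 512, 101)  # token-aligned conditioning; 101 * 4 != 407
    cond = timestep_independent_sketch(code_emb, x.shape[-1])
    assert cond.shape[-1] == x.shape[-1]  # safe to combine with x along the time axis

Deriving the length from x rather than from aligned_conditioning decouples the conditioning path from any fixed upsampling ratio, so inputs whose lengths are not an exact multiple of the conditioning length no longer produce a mismatched expansion.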