forked from mrq/DL-Art-School
fix bug: pass the expected sequence length into timestep_independent instead of assuming a 4x conditioning ratio
parent 035bcd9f6c
commit 8623c51902
@@ -200,7 +200,7 @@ class DiffusionTtsFlat(nn.Module):
         }
         return groups
 
-    def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
+    def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
         # Shuffle aligned_latent to BxCxS format
         if is_latent(aligned_conditioning):
             aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
@@ -228,7 +228,7 @@ class DiffusionTtsFlat(nn.Module):
                                                device=code_emb.device) < self.unconditioned_percentage
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                    code_emb)
-        expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
+        expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
 
         if not return_code_pred:
             return expanded_code_emb
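Why this matters: the old call stretched code_emb to a hard-coded 4x the conditioning length, which only lines up with the diffusion input when that 4x ratio actually holds; the new call stretches it to whatever target length the caller passes in. A minimal standalone sketch of the difference, with assumed illustrative shapes (the real channel counts and lengths come from the model config):

import torch
import torch.nn.functional as F

x = torch.randn(2, 512, 437)          # assumed diffusion input, [B, C, T]
code_emb = torch.randn(2, 512, 100)   # assumed code embedding, [B, C, S]

# Old sizing: always 4x the conditioning length -> 400, which need not match T.
old = F.interpolate(code_emb, size=code_emb.shape[-1] * 4, mode='nearest')

# New sizing: the caller supplies the target length (x.shape[-1] in forward()).
new = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')

print(old.shape)   # torch.Size([2, 512, 400])
print(new.shape)   # torch.Size([2, 512, 437]) -- matches x downstream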
@@ -263,7 +263,7 @@ class DiffusionTtsFlat(nn.Module):
         if precomputed_aligned_embeddings is not None:
             code_emb = precomputed_aligned_embeddings
         else:
-            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
+            code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
 
         unused_params.append(self.unconditioned_embedding)
         if is_latent(aligned_conditioning):
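Note for external callers: timestep_independent now takes the expected output length as its third positional argument, and forward() threads x.shape[-1] through as shown in the hunk above. A hedged sketch of the new call contract, using a stub with the same signature rather than the real model (all names and shapes here are illustrative assumptions):

import torch
import torch.nn.functional as F

def timestep_independent(aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
    # Stub standing in for DiffusionTtsFlat.timestep_independent: treat the
    # conditioning as the code embedding and stretch it to the target length.
    code_emb = aligned_conditioning
    expanded = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
    return (expanded, None) if return_code_pred else expanded

x = torch.randn(2, 100, 437)       # diffusion input, [B, C, T]
codes = torch.randn(2, 100, 50)    # conditioning of arbitrary length
cond = torch.randn(2, 100, 80)     # unused by the stub

code_emb, mel_pred = timestep_independent(codes, cond, x.shape[-1], True)
assert code_emb.shape[-1] == x.shape[-1]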