fix conditioning_free signal

This commit is contained in:
James Betker 2022-03-21 15:29:17 -06:00
parent 2a65c982ca
commit 9c7598dc9a

View File

@ -158,7 +158,7 @@ class DiffusionTtsFlat(nn.Module):
# Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
unused_params = []
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
code_emb = self.unconditioned_embedding.repeat(conditioning_input.shape[0], 1, 1)
else:
unused_params.append(self.unconditioned_embedding)
cond_emb = self.contextual_embedder(conditioning_input)