take the conditioning mean rather than the first element

James Betker 2022-03-21 16:58:03 -06:00
parent 9c7598dc9a
commit 9e97cd800c

@@ -161,7 +161,14 @@ class DiffusionTtsFlat(nn.Module):
             code_emb = self.unconditioned_embedding.repeat(conditioning_input.shape[0], 1, 1)
         else:
             unused_params.append(self.unconditioned_embedding)
-            cond_emb = self.contextual_embedder(conditioning_input)
+            speech_conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
+            conds = []
+            for j in range(speech_conditioning_input.shape[1]):
+                conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
+            conds = torch.stack(conds, dim=1)
+            cond_emb = conds.mean(dim=1)
             if len(cond_emb.shape) == 3:  # Just take the first element.
                 cond_emb = cond_emb[:, :, 0]
         if is_latent(aligned_conditioning):
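The hunk above replaces the single embedder call with a per-clip loop: the conditioning input may now carry several reference clips per sample, each clip is embedded independently, and the embeddings are averaged rather than depending on one clip. Below is a minimal, self-contained sketch of that pattern; the embedder module is a hypothetical stand-in with compatible shapes, not the repo's contextual_embedder.

import torch
import torch.nn as nn

# Stand-in for self.contextual_embedder: maps (batch, 100 mel bins, time) -> (batch, 512).
embedder = nn.Sequential(
    nn.Conv1d(100, 512, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
)

def embed_conditioning(conditioning_input):
    # Accept a single clip (b, mel, t) or a stack of clips (b, n, mel, t).
    if len(conditioning_input.shape) == 3:
        conditioning_input = conditioning_input.unsqueeze(1)
    conds = []
    for j in range(conditioning_input.shape[1]):
        conds.append(embedder(conditioning_input[:, j]))  # embed each clip separately
    conds = torch.stack(conds, dim=1)  # (b, n, 512)
    return conds.mean(dim=1)           # average over clips instead of picking one

cond_emb = embed_conditioning(torch.randn(2, 2, 100, 400))
print(cond_emb.shape)  # torch.Size([2, 512])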
@@ -224,7 +231,7 @@ if __name__ == '__main__':
     clip = torch.randn(2, 100, 400)
     aligned_latent = torch.randn(2,388,512)
     aligned_sequence = torch.randint(0,8192,(2,388))
-    cond = torch.randn(2, 100, 400)
+    cond = torch.randn(2, 2, 100, 400)
     ts = torch.LongTensor([600, 600])
     model = DiffusionTtsFlat(512, layer_drop=.3)
     # Test with latent aligned conditioning
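In the updated test harness, cond is a 4-D tensor: a batch of 2 samples with 2 conditioning clips each, 100 mel bins, and 400 frames. As an illustration of how such a tensor might be assembled from reference clips of unequal length, here is a hedged sketch; stack_conditioning_clips and the 400-frame target length are assumptions for the example, not utilities from the repository.

import torch
import torch.nn.functional as F

def stack_conditioning_clips(clips, target_len=400):
    # clips: list of (mel_bins, time) spectrograms for one sample.
    fixed = []
    for clip in clips:
        if clip.shape[-1] >= target_len:
            fixed.append(clip[..., :target_len])                          # truncate long clips
        else:
            fixed.append(F.pad(clip, (0, target_len - clip.shape[-1])))   # zero-pad short clips
    return torch.stack(fixed, dim=0)                                      # (n_clips, mel_bins, time)

sample = [torch.randn(100, 350), torch.randn(100, 420)]
per_sample = stack_conditioning_clips(sample)              # (2, 100, 400)
cond = torch.stack([per_sample, per_sample], dim=0)        # (2, 2, 100, 400), as in the test
print(cond.shape)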