take the conditioning mean rather than the first element
This commit is contained in:
parent
9c7598dc9a
commit
9e97cd800c
|
@ -161,7 +161,14 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
code_emb = self.unconditioned_embedding.repeat(conditioning_input.shape[0], 1, 1)
|
code_emb = self.unconditioned_embedding.repeat(conditioning_input.shape[0], 1, 1)
|
||||||
else:
|
else:
|
||||||
unused_params.append(self.unconditioned_embedding)
|
unused_params.append(self.unconditioned_embedding)
|
||||||
cond_emb = self.contextual_embedder(conditioning_input)
|
|
||||||
|
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
|
||||||
|
conds = []
|
||||||
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
|
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||||
|
conds = torch.stack(conds, dim=1)
|
||||||
|
cond_emb = conds.mean(dim=1)
|
||||||
|
|
||||||
if len(cond_emb.shape) == 3: # Just take the first element.
|
if len(cond_emb.shape) == 3: # Just take the first element.
|
||||||
cond_emb = cond_emb[:, :, 0]
|
cond_emb = cond_emb[:, :, 0]
|
||||||
if is_latent(aligned_conditioning):
|
if is_latent(aligned_conditioning):
|
||||||
|
@ -224,7 +231,7 @@ if __name__ == '__main__':
|
||||||
clip = torch.randn(2, 100, 400)
|
clip = torch.randn(2, 100, 400)
|
||||||
aligned_latent = torch.randn(2,388,512)
|
aligned_latent = torch.randn(2,388,512)
|
||||||
aligned_sequence = torch.randint(0,8192,(2,388))
|
aligned_sequence = torch.randint(0,8192,(2,388))
|
||||||
cond = torch.randn(2, 100, 400)
|
cond = torch.randn(2, 2, 100, 400)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = DiffusionTtsFlat(512, layer_drop=.3)
|
model = DiffusionTtsFlat(512, layer_drop=.3)
|
||||||
# Test with latent aligned conditioning
|
# Test with latent aligned conditioning
|
||||||
|
|
Loading…
Reference in New Issue
Block a user