forked from mrq/DL-Art-School

take the conditioning mean rather than the first element

commit 9e97cd800c (parent 9c7598dc9a)
@@ -161,7 +161,14 @@ class DiffusionTtsFlat(nn.Module):
             code_emb = self.unconditioned_embedding.repeat(conditioning_input.shape[0], 1, 1)
         else:
             unused_params.append(self.unconditioned_embedding)
-            cond_emb = self.contextual_embedder(conditioning_input)
+            speech_conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
+            conds = []
+            for j in range(speech_conditioning_input.shape[1]):
+                conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
+            conds = torch.stack(conds, dim=1)
+            cond_emb = conds.mean(dim=1)
+
             if len(cond_emb.shape) == 3:  # Just take the first element.
                 cond_emb = cond_emb[:, :, 0]
             if is_latent(aligned_conditioning):
@@ -224,7 +231,7 @@ if __name__ == '__main__':
     clip = torch.randn(2, 100, 400)
     aligned_latent = torch.randn(2,388,512)
     aligned_sequence = torch.randint(0,8192,(2,388))
-    cond = torch.randn(2, 100, 400)
+    cond = torch.randn(2, 2, 100, 400)
     ts = torch.LongTensor([600, 600])
     model = DiffusionTtsFlat(512, layer_drop=.3)
     # Test with latent aligned conditioning
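For context on the shape handling, here is a minimal standalone sketch of the new conditioning path. The contextual_embedder below is a stand-in flatten-plus-linear layer (the real module in DiffusionTtsFlat is a learned encoder), and the (100, 400) clip shape and 512-dim embedding simply mirror the test harness above; only the stack-and-mean logic is the point.

import torch
import torch.nn as nn

# Stand-in for self.contextual_embedder: maps one conditioning clip of shape
# (batch, 100, 400) to a (batch, 512) embedding. The real module is a learned
# encoder; flatten + linear is used here only to make the shape logic runnable.
contextual_embedder = nn.Sequential(nn.Flatten(1), nn.Linear(100 * 400, 512))

def embed_conditioning(conditioning_input):
    # Accept a single clip (b, 100, 400) or several clips per sample (b, n, 100, 400).
    speech_conditioning_input = conditioning_input.unsqueeze(1) \
        if len(conditioning_input.shape) == 3 else conditioning_input
    conds = []
    for j in range(speech_conditioning_input.shape[1]):
        conds.append(contextual_embedder(speech_conditioning_input[:, j]))
    conds = torch.stack(conds, dim=1)  # (b, n, 512)
    return conds.mean(dim=1)           # mean over clips -> (b, 512)

print(embed_conditioning(torch.randn(2, 100, 400)).shape)     # torch.Size([2, 512])
print(embed_conditioning(torch.randn(2, 2, 100, 400)).shape)  # torch.Size([2, 512])

Averaging over the clip dimension yields one fixed-size conditioning vector per sample regardless of how many conditioning clips are supplied, which is why the test input grows from (2, 100, 400) to (2, 2, 100, 400).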