forked from mrq/DL-Art-School
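unified_voice fixes: average the stacked conditioning embeddings (mean over dim 1) when average_conditioning_embeddings is set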
This commit is contained in:
parent
1ad18d29a8
commit
e735d8e1fa
|
@@ -434,6 +434,8 @@ class UnifiedVoice(nn.Module):
|
||||||
for j in range(speech_conditioning_input.shape[1]):
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||||
conds = torch.stack(conds, dim=1)
|
conds = torch.stack(conds, dim=1)
|
||||||
|
if self.average_conditioning_embeddings:
|
||||||
|
conds = conds.mean(dim=1).unsqueeze(1)
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
|
||||||
|
@@ -460,6 +462,8 @@ class UnifiedVoice(nn.Module):
|
||||||
for j in range(speech_conditioning_input.shape[1]):
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||||
conds = torch.stack(conds, dim=1)
|
conds = torch.stack(conds, dim=1)
|
||||||
|
if self.average_conditioning_embeddings:
|
||||||
|
conds = conds.mean(dim=1).unsqueeze(1)
|
||||||
|
|
||||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||||
if raw_mels is not None:
|
if raw_mels is not None:
|
||||||
|
@@ -496,6 +500,8 @@ class UnifiedVoice(nn.Module):
|
||||||
for j in range(speech_conditioning_input.shape[1]):
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||||
conds = torch.stack(conds, dim=1)
|
conds = torch.stack(conds, dim=1)
|
||||||
|
if self.average_conditioning_embeddings:
|
||||||
|
conds = conds.mean(dim=1).unsqueeze(1)
|
||||||
|
|
||||||
emb = torch.cat([conds, text_emb], dim=1)
|
emb = torch.cat([conds, text_emb], dim=1)
|
||||||
self.inference_model.store_mel_emb(emb)
|
self.inference_model.store_mel_emb(emb)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user