From e735d8e1fa70eb7b66e3052e282ef40b764247c5 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 21 Mar 2022 14:44:00 -0600 Subject: [PATCH] unified_voice fixes --- codes/models/audio/tts/unified_voice2.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py index f4a3deec..78beb2d0 100644 --- a/codes/models/audio/tts/unified_voice2.py +++ b/codes/models/audio/tts/unified_voice2.py @@ -434,6 +434,8 @@ class UnifiedVoice(nn.Module): for j in range(speech_conditioning_input.shape[1]): conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) + if self.average_conditioning_embeddings: + conds = conds.mean(dim=1).unsqueeze(1) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding @@ -460,6 +462,8 @@ class UnifiedVoice(nn.Module): for j in range(speech_conditioning_input.shape[1]): conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) + if self.average_conditioning_embeddings: + conds = conds.mean(dim=1).unsqueeze(1) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) if raw_mels is not None: @@ -496,6 +500,8 @@ class UnifiedVoice(nn.Module): for j in range(speech_conditioning_input.shape[1]): conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) + if self.average_conditioning_embeddings: + conds = conds.mean(dim=1).unsqueeze(1) emb = torch.cat([conds, text_emb], dim=1) self.inference_model.store_mel_emb(emb)