diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py index 02fbf956..eb481c91 100644 --- a/codes/models/audio/tts/unified_voice2.py +++ b/codes/models/audio/tts/unified_voice2.py @@ -365,6 +365,17 @@ class UnifiedVoice(nn.Module): else: return first_logits + + def get_conditioning_latent(self, speech_conditioning_input): + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + if self.average_conditioning_embeddings: + conds = conds.mean(dim=1).unsqueeze(1) + return conds + def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False, return_latent=False, clip_inputs=True): """ @@ -399,13 +410,7 @@ class UnifiedVoice(nn.Module): text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) + conds = self.get_conditioning_latent(speech_conditioning_input) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) diff --git a/codes/train.py b/codes/train.py index e2601790..f13e8e55 100644 --- a/codes/train.py +++ b/codes/train.py @@ -327,7 +327,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_reverse_classifier.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_random_latent_gen.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)