break out get_conditioning_latent from unified_voice
This commit is contained in:
parent
afa2df57c9
commit
b712d3b72b
|
@ -365,6 +365,17 @@ class UnifiedVoice(nn.Module):
|
|||
else:
|
||||
return first_logits
|
||||
|
||||
|
||||
def get_conditioning_latent(self, speech_conditioning_input):
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
return conds
|
||||
|
||||
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
|
||||
return_latent=False, clip_inputs=True):
|
||||
"""
|
||||
|
@ -399,13 +410,7 @@ class UnifiedVoice(nn.Module):
|
|||
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
||||
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
|
||||
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
conds = self.get_conditioning_latent(speech_conditioning_input)
|
||||
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
|
|
@ -327,7 +327,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_reverse_classifier.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_random_latent_gen.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user