From eda753e77612020a9ff98ee26207582e0a23f5b7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 31 Dec 2021 23:32:08 -0700 Subject: [PATCH] Allow conditioning shuffling to be disabled --- codes/models/gpt_voice/unified_voice.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py index e9a08964..235a1bd8 100644 --- a/codes/models/gpt_voice/unified_voice.py +++ b/codes/models/gpt_voice/unified_voice.py @@ -50,7 +50,7 @@ class UnifiedGptVoice(nn.Module): def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3, checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256, start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, - stop_mel_token=8193, use_dedicated_position_embeddings_for_paired=True): + stop_mel_token=8193, use_dedicated_position_embeddings_for_paired=True, shuffle_conditioning=True): super().__init__() self.number_text_tokens = number_text_tokens @@ -59,6 +59,7 @@ class UnifiedGptVoice(nn.Module): self.number_mel_codes = number_mel_codes self.start_mel_token = start_mel_token self.stop_mel_token = stop_mel_token + self.shuffle_conditioning = shuffle_conditioning self.max_mel_tokens = max_mel_tokens self.max_symbols_per_phrase = max_symbols_per_phrase @@ -171,7 +172,8 @@ class UnifiedGptVoice(nn.Module): assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}' mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths) - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) + if self.shuffle_conditioning: + speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) @@ -197,7 +199,8 @@ class UnifiedGptVoice(nn.Module): """ assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}' - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) + if self.shuffle_conditioning: + speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) @@ -213,7 +216,8 @@ class UnifiedGptVoice(nn.Module): assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}' mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths) - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) + if self.shuffle_conditioning: + speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token) @@ -230,8 +234,9 @@ class UnifiedGptVoice(nn.Module): text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - # Randomly permute the conditioning spectrogram, to destroy any structure present. - speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) + if self.shuffle_conditioning: + # Randomly permute the conditioning spectrogram, to destroy any structure present. + speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) emb = torch.cat([cond, text_emb], dim=1)