forked from mrq/DL-Art-School
Allow conditioning shuffling to be disabled
This commit is contained in:
parent
17fb934575
commit
eda753e776
|
@ -50,7 +50,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
|
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
|
||||||
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
|
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
|
||||||
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
||||||
stop_mel_token=8193, use_dedicated_position_embeddings_for_paired=True):
|
stop_mel_token=8193, use_dedicated_position_embeddings_for_paired=True, shuffle_conditioning=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.number_text_tokens = number_text_tokens
|
self.number_text_tokens = number_text_tokens
|
||||||
|
@ -59,6 +59,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.number_mel_codes = number_mel_codes
|
self.number_mel_codes = number_mel_codes
|
||||||
self.start_mel_token = start_mel_token
|
self.start_mel_token = start_mel_token
|
||||||
self.stop_mel_token = stop_mel_token
|
self.stop_mel_token = stop_mel_token
|
||||||
|
self.shuffle_conditioning = shuffle_conditioning
|
||||||
|
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_symbols_per_phrase = max_symbols_per_phrase
|
self.max_symbols_per_phrase = max_symbols_per_phrase
|
||||||
|
@ -171,6 +172,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
|
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
|
||||||
|
|
||||||
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
|
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
|
||||||
|
if self.shuffle_conditioning:
|
||||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
|
@ -197,6 +199,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
"""
|
"""
|
||||||
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
||||||
|
|
||||||
|
if self.shuffle_conditioning:
|
||||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
|
@ -213,6 +216,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
|
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
|
||||||
|
|
||||||
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
|
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
|
||||||
|
if self.shuffle_conditioning:
|
||||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
|
@ -230,6 +234,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
||||||
|
|
||||||
|
if self.shuffle_conditioning:
|
||||||
# Randomly permute the conditioning spectrogram, to destroy any structure present.
|
# Randomly permute the conditioning spectrogram, to destroy any structure present.
|
||||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||||
cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user