unified_voice improvements

- Rename max_symbols_per_phrase to max_text_tokens
- Remove max_total_tokens (no longer necessary)
- Fix integration with MelEncoder
This commit is contained in:
James Betker 2022-01-05 17:03:53 -07:00
parent 50d267ab1a
commit c584ba05ee

View File

@ -56,7 +56,7 @@ class MelEncoder(nn.Module):
def forward(self, x): def forward(self, x):
for e in self.encoder: for e in self.encoder:
x = e(x) x = e(x)
return x return x.permute(0,2,1)
def null_position_embeddings(range, dim): def null_position_embeddings(range, dim):
@ -72,10 +72,32 @@ class UnifiedGptVoice(nn.Module):
- Voice conditioned on text - Voice conditioned on text
""" """
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3, def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256, max_conditioning_length=60, shuffle_conditioning=True, mel_length_compression=1024, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, shuffle_conditioning=True, train_solo_embeddings=False, use_mel_codes_as_input=True): stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True):
"""
Args:
layers: Number of layers in transformer stack.
model_dim: Operating dimensions of the transformer
heads: Number of transformer heads. model_dim must be divisible by heads. Recommend model_dim//64
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
max_conditioning_length: Maximum length of conditioning input. Only needed if shuffle_conditioning=True
shuffle_conditioning: Whether or not the conditioning inputs will be shuffled across the sequence dimension. Useful if you want to provide the same input as conditioning and mel_codes.
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
start_text_token:
stop_text_token:
number_mel_codes:
start_mel_token:
stop_mel_token:
train_solo_embeddings:
use_mel_codes_as_input:
checkpointing:
"""
super().__init__() super().__init__()
self.number_text_tokens = number_text_tokens self.number_text_tokens = number_text_tokens
@ -87,16 +109,15 @@ class UnifiedGptVoice(nn.Module):
self.shuffle_conditioning = shuffle_conditioning self.shuffle_conditioning = shuffle_conditioning
self.max_mel_tokens = max_mel_tokens self.max_mel_tokens = max_mel_tokens
self.max_symbols_per_phrase = max_symbols_per_phrase self.max_text_tokens = max_text_tokens
self.max_total_tokens = max_total_tokens
self.model_dim = model_dim self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 1, model_dim)
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim) self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes, self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
n_positions=seq_length, n_positions=seq_length,
n_ctx=seq_length, n_ctx=seq_length,
@ -184,7 +205,7 @@ class UnifiedGptVoice(nn.Module):
else: else:
return first_logits return first_logits
def forward(self, speech_conditioning_input, text_inputs, mel_inputs, wav_lengths, text_first=True, return_attentions=False): def forward(self, speech_conditioning_input, text_inputs, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
""" """
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`). (actuated by `text_first`).
@ -193,20 +214,24 @@ class UnifiedGptVoice(nn.Module):
text_inputs: long tensor, (b,t) text_inputs: long tensor, (b,t)
mel_inputs: long tensor, (b,m) mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,) wav_lengths: long tensor, (b,)
raw_mels: MEL float tensor (b,80,s)
""" """
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}' assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}' assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths) mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
if self.shuffle_conditioning: if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
mel_emb = self.gpt.get_input_embeddings()(mel_inputs) if raw_mels is not None:
mel_inp = F.pad(raw_mels, (0, 4))
else:
mel_inp = mel_codes
mel_emb = self.gpt.get_input_embeddings()(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
if text_first: if text_first:
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
@ -224,7 +249,7 @@ class UnifiedGptVoice(nn.Module):
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided). model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
""" """
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}' assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
if self.shuffle_conditioning: if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
@ -236,19 +261,23 @@ class UnifiedGptVoice(nn.Module):
loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean() return loss_text.mean()
def speech_forward(self, speech_conditioning_input, mel_inputs, wav_lengths): def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
""" """
Performs autoregressive modeling on only speech data. Performs autoregressive modeling on only speech data.
""" """
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}' assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths) mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
if self.shuffle_conditioning: if self.shuffle_conditioning:
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
mel_emb = self.gpt.get_input_embeddings()(mel_inputs) if raw_mels is not None:
mel_inp = F.pad(raw_mels, (0, 4))
else:
mel_inp = mel_codes
mel_emb = self.gpt.get_input_embeddings()(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head) mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
@ -283,8 +312,10 @@ def register_unified_gpt_voice(opt_net, opt):
if __name__ == '__main__': if __name__ == '__main__':
gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True) gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=False)
l = gpt(torch.randn(2, 80, 800), l = gpt(torch.randn(2, 80, 800),
torch.randint(high=len(symbols), size=(2,80)), torch.randint(high=len(symbols), size=(2,80)),
torch.randint(high=8192, size=(2,250)), torch.randint(high=8192, size=(2,250)),
torch.tensor([150*256,195*256])) torch.tensor([150*256,195*256]),
raw_mels=torch.randn(2,80,1000))
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)))