Update autoregressive to support type inputs

This commit is contained in:
James Betker 2022-04-18 09:22:27 -06:00
parent 713281e376
commit 76c30fe344

View File

@ -278,9 +278,10 @@ class MelEncoder(nn.Module):
class UnifiedVoice(nn.Module): class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256, mel_length_compression=1024, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, average_conditioning_embeddings=False): checkpointing=True, average_conditioning_embeddings=False,
types=1):
""" """
Args: Args:
layers: Number of layers in transformer stack. layers: Number of layers in transformer stack.
@ -304,8 +305,8 @@ class UnifiedVoice(nn.Module):
super().__init__() super().__init__()
self.number_text_tokens = number_text_tokens self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
self.stop_text_token = stop_text_token self.stop_text_token = 0
self.number_mel_codes = number_mel_codes self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token self.stop_mel_token = stop_mel_token
@ -318,7 +319,7 @@ class UnifiedVoice(nn.Module):
self.mel_length_compression = mel_length_compression self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.average_conditioning_embeddings = average_conditioning_embeddings self.average_conditioning_embeddings = average_conditioning_embeddings
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
if use_mel_codes_as_input: if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
else: else:
@ -333,7 +334,7 @@ class UnifiedVoice(nn.Module):
self.text_solo_embedding = 0 self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim) self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens) self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes) self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme # Initialize the embeddings per the GPT-2 scheme
@ -389,7 +390,7 @@ class UnifiedVoice(nn.Module):
else: else:
return first_logits return first_logits
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False, def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
return_latent=False, clip_inputs=True): return_latent=False, clip_inputs=True):
""" """
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@ -406,6 +407,10 @@ class UnifiedVoice(nn.Module):
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
""" """
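# Types are expressed by expanding the text embedding space: every text token id is multiplied by (1 + types), broadcast over the last dimension.
# NOTE(review): scaling does not give each type its own disjoint id range (id 2 with type 0 and id 1 with type 1 both become 2) - confirm this is intended rather than an offset of types * number_text_tokens.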
if types is not None:
text_inputs = text_inputs * (1+types).unsqueeze(-1)
if clip_inputs: if clip_inputs:
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length. # chopping the inputs by the maximum actual length.