forked from mrq/tortoise-tts
Update autoregressive to support type inputs
This commit is contained in:
parent
2dea4952d5
commit
4281b64517
|
@ -278,9 +278,10 @@ class MelEncoder(nn.Module):
|
||||||
class UnifiedVoice(nn.Module):
|
class UnifiedVoice(nn.Module):
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
||||||
mel_length_compression=1024, number_text_tokens=256,
|
mel_length_compression=1024, number_text_tokens=256,
|
||||||
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
|
||||||
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
||||||
checkpointing=True, average_conditioning_embeddings=False):
|
checkpointing=True, average_conditioning_embeddings=False,
|
||||||
|
types=1):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
layers: Number of layers in transformer stack.
|
layers: Number of layers in transformer stack.
|
||||||
|
@ -304,8 +305,8 @@ class UnifiedVoice(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.number_text_tokens = number_text_tokens
|
self.number_text_tokens = number_text_tokens
|
||||||
self.start_text_token = start_text_token
|
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
|
||||||
self.stop_text_token = stop_text_token
|
self.stop_text_token = 0
|
||||||
self.number_mel_codes = number_mel_codes
|
self.number_mel_codes = number_mel_codes
|
||||||
self.start_mel_token = start_mel_token
|
self.start_mel_token = start_mel_token
|
||||||
self.stop_mel_token = stop_mel_token
|
self.stop_mel_token = stop_mel_token
|
||||||
|
@ -318,7 +319,7 @@ class UnifiedVoice(nn.Module):
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||||
self.average_conditioning_embeddings = average_conditioning_embeddings
|
self.average_conditioning_embeddings = average_conditioning_embeddings
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
|
||||||
if use_mel_codes_as_input:
|
if use_mel_codes_as_input:
|
||||||
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
||||||
else:
|
else:
|
||||||
|
@ -333,7 +334,7 @@ class UnifiedVoice(nn.Module):
|
||||||
self.text_solo_embedding = 0
|
self.text_solo_embedding = 0
|
||||||
|
|
||||||
self.final_norm = nn.LayerNorm(model_dim)
|
self.final_norm = nn.LayerNorm(model_dim)
|
||||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
|
||||||
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
||||||
|
|
||||||
# Initialize the embeddings per the GPT-2 scheme
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
|
@ -389,7 +390,7 @@ class UnifiedVoice(nn.Module):
|
||||||
else:
|
else:
|
||||||
return first_logits
|
return first_logits
|
||||||
|
|
||||||
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
|
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
|
||||||
return_latent=False, clip_inputs=True):
|
return_latent=False, clip_inputs=True):
|
||||||
"""
|
"""
|
||||||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||||
|
@ -406,6 +407,10 @@ class UnifiedVoice(nn.Module):
|
||||||
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
|
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
|
||||||
If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
|
If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
|
||||||
"""
|
"""
|
||||||
|
# Types are expressed by expanding the text embedding space.
|
||||||
|
if types is not None:
|
||||||
|
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
||||||
|
|
||||||
if clip_inputs:
|
if clip_inputs:
|
||||||
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||||
# chopping the inputs by the maximum actual length.
|
# chopping the inputs by the maximum actual length.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user