forked from mrq/DL-Art-School
Shamelessly nabbed from ae80992817
(if this makes a big enough difference in training i'm going to cum)
This commit is contained in: parent 0ee0f46596, commit 84c8196da5
@@ -243,7 +243,8 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False,
+                 tortoise_compat=True):
         """
         Args:
             layers: Number of layers in transformer stack.
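For reference, a minimal sketch of how the patched constructor might be invoked once this change lands. The layers/model_dim/heads values below are illustrative placeholders rather than values taken from this commit, and every other argument keeps the default visible in the signature above:

    # Hypothetical instantiation; only tortoise_compat is new in this commit.
    model = UnifiedVoice(layers=30, model_dim=1024, heads=16,
                         mel_length_compression=1024,
                         use_mel_codes_as_input=True,
                         checkpointing=True,
                         tortoise_compat=True)  # new flag, defaults to True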
@@ -281,6 +282,7 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
+        self.tortoise_compat = tortoise_compat # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b
         # nn.Embedding
         self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
         if use_mel_codes_as_input:
@@ -301,6 +303,7 @@ class UnifiedVoice(nn.Module):
         self.text_head = ml.Linear(model_dim, self.number_text_tokens)
         self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
 
+
         # Initialize the embeddings per the GPT-2 scheme
         embeddings = [self.text_embedding]
         if use_mel_codes_as_input:
@@ -386,6 +389,8 @@ class UnifiedVoice(nn.Module):
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         """
+        if self.tortoise_compat:
+            wav_lengths *= self.mel_length_compression
         # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
         # chopping the inputs by the maximum actual length.
         max_text_len = text_lengths.max()
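A rough numeric illustration of the new compatibility branch above. The unit semantics are my assumption, not stated in the diff: if lengths arrive counted in mel codes rather than waveform samples, multiplying by mel_length_compression (1024) restores the sample-count unit the padding-trimming logic below works with:

    import torch

    mel_length_compression = 1024
    wav_lengths = torch.tensor([220, 310])               # hypothetical per-clip lengths, in mel codes
    wav_lengths = wav_lengths * mel_length_compression   # what the compat branch does
    # -> tensor([225280, 317440]), i.e. approximate lengths in samples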
@@ -414,14 +419,15 @@ class UnifiedVoice(nn.Module):
         mel_emb = self.mel_embedding(mel_inp)
         mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
 
+        sub = -2 if self.tortoise_compat else -1
         if text_first:
             text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return mel_logits[:, :-1] # Despite the name, these are not logits.
+                return mel_logits[:, :sub] # Despite the name, these are not logits.
         else:
             mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return text_logits[:, :-1] # Despite the name, these are not logits
+                return text_logits[:, :sub] # Despite the name, these are not logits
 
         if return_attentions:
             return mel_logits
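The other effect of the flag is the slicing offset on the returned latents: with tortoise_compat enabled the last two sequence positions are dropped instead of one, presumably to line up with how the TorToiSe inference-side autoregressive model strips the tokens appended by its forward pass. A small sketch of the difference, using a made-up (batch, sequence, dim) tensor as a stand-in for the latents:

    import torch

    latents = torch.randn(1, 10, 1024)    # stand-in for the returned latents
    print(latents[:, :-2].shape)          # tortoise_compat=True  -> torch.Size([1, 8, 1024])
    print(latents[:, :-1].shape)          # tortoise_compat=False -> torch.Size([1, 9, 1024])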