forked from mrq/DL-Art-School
unified_voice improvements
- Rename max_symbols_per_phrase to max_text_tokens - Remove max_total_tokens (no longer necessary) - Fix integration with MelEncoder
This commit is contained in:
parent
50d267ab1a
commit
c584ba05ee
|
@ -56,7 +56,7 @@ class MelEncoder(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for e in self.encoder:
|
for e in self.encoder:
|
||||||
x = e(x)
|
x = e(x)
|
||||||
return x
|
return x.permute(0,2,1)
|
||||||
|
|
||||||
|
|
||||||
def null_position_embeddings(range, dim):
|
def null_position_embeddings(range, dim):
|
||||||
|
@ -72,10 +72,32 @@ class UnifiedGptVoice(nn.Module):
|
||||||
- Voice conditioned on text
|
- Voice conditioned on text
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
|
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
||||||
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
|
max_conditioning_length=60, shuffle_conditioning=True, mel_length_compression=1024, number_text_tokens=256,
|
||||||
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
||||||
stop_mel_token=8193, shuffle_conditioning=True, train_solo_embeddings=False, use_mel_codes_as_input=True):
|
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
||||||
|
checkpointing=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
layers: Number of layers in transformer stack.
|
||||||
|
model_dim: Operating dimensions of the transformer
|
||||||
|
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
|
||||||
|
max_text_tokens: Maximum number of text tokens that will be encountered by model.
|
||||||
|
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
|
||||||
|
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
|
||||||
|
max_conditioning_length: Maximum length of conditioning input. Only needed if shuffle_conditioning=True
|
||||||
|
shuffle_conditioning: Whether or not the conditioning inputs will be shuffled across the sequence dimension. Useful if you want to provide the same input as conditioning and mel_codes.
|
||||||
|
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
|
||||||
|
number_text_tokens:
|
||||||
|
start_text_token:
|
||||||
|
stop_text_token:
|
||||||
|
number_mel_codes:
|
||||||
|
start_mel_token:
|
||||||
|
stop_mel_token:
|
||||||
|
train_solo_embeddings:
|
||||||
|
use_mel_codes_as_input:
|
||||||
|
checkpointing:
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.number_text_tokens = number_text_tokens
|
self.number_text_tokens = number_text_tokens
|
||||||
|
@ -87,16 +109,15 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.shuffle_conditioning = shuffle_conditioning
|
self.shuffle_conditioning = shuffle_conditioning
|
||||||
|
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_symbols_per_phrase = max_symbols_per_phrase
|
self.max_text_tokens = max_text_tokens
|
||||||
self.max_total_tokens = max_total_tokens
|
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 1, model_dim)
|
||||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
||||||
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
|
seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
|
||||||
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
||||||
n_positions=seq_length,
|
n_positions=seq_length,
|
||||||
n_ctx=seq_length,
|
n_ctx=seq_length,
|
||||||
|
@ -184,7 +205,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
else:
|
else:
|
||||||
return first_logits
|
return first_logits
|
||||||
|
|
||||||
def forward(self, speech_conditioning_input, text_inputs, mel_inputs, wav_lengths, text_first=True, return_attentions=False):
|
def forward(self, speech_conditioning_input, text_inputs, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
|
||||||
"""
|
"""
|
||||||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||||
(actuated by `text_first`).
|
(actuated by `text_first`).
|
||||||
|
@ -193,20 +214,24 @@ class UnifiedGptVoice(nn.Module):
|
||||||
text_inputs: long tensor, (b,t)
|
text_inputs: long tensor, (b,t)
|
||||||
mel_inputs: long tensor, (b,m)
|
mel_inputs: long tensor, (b,m)
|
||||||
wav_lengths: long tensor, (b,)
|
wav_lengths: long tensor, (b,)
|
||||||
|
raw_mels: MEL float tensor (b,80,s)
|
||||||
"""
|
"""
|
||||||
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
|
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
|
||||||
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
||||||
assert self.max_total_tokens >= mel_inputs.shape[1] + text_inputs.shape[1], f'{mel_inputs.shape[1]}, {text_inputs.shape[1]}'
|
|
||||||
|
|
||||||
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
|
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||||
if self.shuffle_conditioning:
|
if self.shuffle_conditioning:
|
||||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
if raw_mels is not None:
|
||||||
|
mel_inp = F.pad(raw_mels, (0, 4))
|
||||||
|
else:
|
||||||
|
mel_inp = mel_codes
|
||||||
|
mel_emb = self.gpt.get_input_embeddings()(mel_inp)
|
||||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||||
if text_first:
|
if text_first:
|
||||||
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
||||||
|
@ -224,7 +249,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
|
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
|
||||||
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
|
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
|
||||||
"""
|
"""
|
||||||
assert self.max_symbols_per_phrase >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
||||||
|
|
||||||
if self.shuffle_conditioning:
|
if self.shuffle_conditioning:
|
||||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||||
|
@ -236,19 +261,23 @@ class UnifiedGptVoice(nn.Module):
|
||||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
return loss_text.mean()
|
return loss_text.mean()
|
||||||
|
|
||||||
def speech_forward(self, speech_conditioning_input, mel_inputs, wav_lengths):
|
def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
|
||||||
"""
|
"""
|
||||||
Performs autoregressive modeling on only speech data.
|
Performs autoregressive modeling on only speech data.
|
||||||
"""
|
"""
|
||||||
assert self.max_mel_tokens >= mel_inputs.shape[1], f'{mel_inputs.shape[1]}'
|
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
|
||||||
|
|
||||||
mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
|
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||||
if self.shuffle_conditioning:
|
if self.shuffle_conditioning:
|
||||||
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
if raw_mels is not None:
|
||||||
|
mel_inp = F.pad(raw_mels, (0, 4))
|
||||||
|
else:
|
||||||
|
mel_inp = mel_codes
|
||||||
|
mel_emb = self.gpt.get_input_embeddings()(mel_inp)
|
||||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding
|
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding
|
||||||
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
|
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
|
||||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||||
|
@ -283,8 +312,10 @@ def register_unified_gpt_voice(opt_net, opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True)
|
gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=False)
|
||||||
l = gpt(torch.randn(2, 80, 800),
|
l = gpt(torch.randn(2, 80, 800),
|
||||||
torch.randint(high=len(symbols), size=(2,80)),
|
torch.randint(high=len(symbols), size=(2,80)),
|
||||||
torch.randint(high=8192, size=(2,250)),
|
torch.randint(high=8192, size=(2,250)),
|
||||||
torch.tensor([150*256,195*256]))
|
torch.tensor([150*256,195*256]),
|
||||||
|
raw_mels=torch.randn(2,80,1000))
|
||||||
|
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user