dataset improvements and fix to unified_voice_Bilevel
This commit is contained in:
parent
eda753e776
commit
bbacffb790
|
@ -135,8 +135,8 @@ class TextWavLoader(torch.utils.data.Dataset):
|
||||||
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
||||||
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
||||||
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
|
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
|
||||||
#if wav is not None:
|
if self.debug_failures:
|
||||||
# print(f"Exception {index} wav_len:{wav.shape[-1]} text_len:{tseq.shape[0]} fname: {path}")
|
print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
|
||||||
rv = random.randint(0,len(self)-1)
|
rv = random.randint(0,len(self)-1)
|
||||||
return self[rv]
|
return self[rv]
|
||||||
orig_output = wav.shape[-1]
|
orig_output = wav.shape[-1]
|
||||||
|
|
|
@ -85,7 +85,10 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, include_self=True,
|
||||||
rand_start = random.randint(0, gap)
|
rand_start = random.randint(0, gap)
|
||||||
rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
|
rel_clip = rel_clip[:, rand_start:rand_start+sample_length]
|
||||||
related_clips.append(rel_clip)
|
related_clips.append(rel_clip)
|
||||||
return torch.stack(related_clips, dim=0)
|
if n > 1:
|
||||||
|
return torch.stack(related_clips, dim=0)
|
||||||
|
else:
|
||||||
|
return related_clips[0]
|
||||||
|
|
||||||
|
|
||||||
class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import functools
|
import functools
|
||||||
|
from math import log
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -12,6 +13,10 @@ from trainer.networks import register_model
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
|
def null_position_embeddings(range, dim):
|
||||||
|
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||||
|
|
||||||
|
|
||||||
class ConditioningEncoder(nn.Module):
|
class ConditioningEncoder(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
spec_dim,
|
spec_dim,
|
||||||
|
@ -34,8 +39,22 @@ class ConditioningEncoder(nn.Module):
|
||||||
return h[:, :, 0]
|
return h[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
def null_position_embeddings(range, dim):
|
class TopEncoder(nn.Module):
|
||||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
def __init__(self, layers, dim, heads, do_checkpointing=False, dim_reduction=16):
|
||||||
|
self.init = nn.Conv1d(dim, dim, kernel_size=1)
|
||||||
|
reduction_layers = []
|
||||||
|
for j in range(int(log(dim_reduction, 2))):
|
||||||
|
reduction_layers.append(AttentionBlock(dim, heads, do_checkpoint=do_checkpointing))
|
||||||
|
reduction_layers.append(nn.Conv1d(dim, dim, kernel_size=3, padding=1, stride=2))
|
||||||
|
self.reduction_layers = nn.Sequential(*reduction_layers)
|
||||||
|
actual_layers = [AttentionBlock(dim, heads, do_checkpoint=do_checkpointing) for _ in range(layers)]
|
||||||
|
self.actual_layers = nn.Sequential(*actual_layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.init(x)
|
||||||
|
h = self.reduction_layers(h)
|
||||||
|
h = self.actual_layers(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
class UnifiedGptVoice(nn.Module):
|
class UnifiedGptVoice(nn.Module):
|
||||||
|
@ -47,7 +66,8 @@ class UnifiedGptVoice(nn.Module):
|
||||||
- Voice conditioned on text
|
- Voice conditioned on text
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
|
def __init__(self, top_encoder_layers=4, top_layers=8, bottom_layers=8, top_dim_reduction=16, model_dim=512, heads=8,
|
||||||
|
max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
|
||||||
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
|
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
|
||||||
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
||||||
stop_mel_token=8193):
|
stop_mel_token=8193):
|
||||||
|
@ -73,18 +93,35 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.mel_pos_solo_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
self.mel_pos_solo_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
||||||
self.mel_pos_paired_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
self.mel_pos_paired_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
||||||
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
|
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
|
||||||
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
|
||||||
n_positions=seq_length,
|
self.top_encoder = TopEncoder(top_encoder_layers, model_dim, heads, do_checkpointing=checkpointing,
|
||||||
n_ctx=seq_length,
|
dim_reduction=top_dim_reduction)
|
||||||
n_embd=model_dim,
|
self.top_gpt_config = GPT2Config(vocab_size=1,
|
||||||
n_layer=layers,
|
n_positions=seq_length // top_dim_reduction,
|
||||||
n_head=heads,
|
n_ctx=seq_length // top_dim_reduction,
|
||||||
gradient_checkpointing=checkpointing,
|
n_embd=model_dim,
|
||||||
use_cache=not checkpointing)
|
n_layer=top_layers,
|
||||||
self.gpt = GPT2Model(self.gpt_config)
|
n_head=heads,
|
||||||
|
gradient_checkpointing=checkpointing,
|
||||||
|
use_cache=not checkpointing)
|
||||||
|
self.top_gpt = GPT2Model(self.top_gpt_config)
|
||||||
|
del self.top_gpt.wte
|
||||||
|
self.top_gpt_start_embedding = nn.Parameter(torch.randn(1,1,model_dim)*self.top_gpt_config.initializer_range,
|
||||||
|
requires_grad=True)
|
||||||
|
self.top_dim_reduction = top_dim_reduction
|
||||||
|
|
||||||
|
self.bottom_gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
||||||
|
n_positions=seq_length,
|
||||||
|
n_ctx=seq_length,
|
||||||
|
n_embd=model_dim,
|
||||||
|
n_layer=bottom_layers,
|
||||||
|
n_head=heads,
|
||||||
|
gradient_checkpointing=checkpointing,
|
||||||
|
use_cache=not checkpointing)
|
||||||
|
self.bottom_gpt = GPT2Model(self.bottom_gpt_config)
|
||||||
# Override the built in positional embeddings
|
# Override the built in positional embeddings
|
||||||
del self.gpt.wpe
|
del self.bottom_gpt.wpe
|
||||||
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
self.bottom_gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||||
|
|
||||||
self.final_norm = nn.LayerNorm(model_dim)
|
self.final_norm = nn.LayerNorm(model_dim)
|
||||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||||
|
@ -94,7 +131,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
# Initialize the embeddings per the GPT-2 scheme
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
for module in [self.text_embedding, self.text_pos_solo_embedding, self.text_pos_paired_embedding,
|
for module in [self.text_embedding, self.text_pos_solo_embedding, self.text_pos_paired_embedding,
|
||||||
self.mel_pos_solo_embedding, self.mel_pos_paired_embedding]:
|
self.mel_pos_solo_embedding, self.mel_pos_paired_embedding]:
|
||||||
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
module.weight.data.normal_(mean=0.0, std=self.bottom_gpt.config.initializer_range)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
|
@ -129,13 +166,34 @@ class UnifiedGptVoice(nn.Module):
|
||||||
cond_input = cond_input[:,:,:self.max_conditioning_length]
|
cond_input = cond_input[:,:,:self.max_conditioning_length]
|
||||||
return cond_input
|
return cond_input
|
||||||
|
|
||||||
|
|
||||||
|
def get_top_embeddings(self, embedded_input):
|
||||||
|
true_embeddings = self.top_encoder(embedded_input)
|
||||||
|
inputs = torch.cat([self.top_gpt_start_embedding, true_embeddings[:,:-1]], dim=1)
|
||||||
|
top_pred = self.top_gpt(inputs_embeds=inputs, return_dict=True)
|
||||||
|
return top_pred.last_hidden_state, true_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def inject_top_embeddings(self, embedded_input, probability_of_true_top_embedding=.5):
|
||||||
|
pred, true = self.get_top_embeddings(embedded_input)
|
||||||
|
rand = torch.bernoulli(torch.full((1,embedded_input.shape[1]),
|
||||||
|
fill_value=probability_of_true_top_embedding)).to(embedded_input.device)
|
||||||
|
mix = pred * rand + true * (not rand)
|
||||||
|
embs = torch.chunk(embedded_input, self.top_dim_reduction, dim=1)
|
||||||
|
assert len(embs) == mix.shape[1]
|
||||||
|
rejoin = []
|
||||||
|
for i, emb in enumerate(embs):
|
||||||
|
rejoin.append(torch.cat([mix[i], emb]), dim=1)
|
||||||
|
return torch.cat(rejoin, dim=1)
|
||||||
|
|
||||||
|
|
||||||
def get_logits(self, speech_conditioning_input, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
|
def get_logits(self, speech_conditioning_input, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
|
||||||
if second_inputs is not None:
|
if second_inputs is not None:
|
||||||
emb = torch.cat([speech_conditioning_input, first_inputs, second_inputs], dim=1)
|
emb = torch.cat([speech_conditioning_input, first_inputs, second_inputs], dim=1)
|
||||||
else:
|
else:
|
||||||
emb = torch.cat([speech_conditioning_input, first_inputs], dim=1)
|
emb = torch.cat([speech_conditioning_input, first_inputs], dim=1)
|
||||||
|
|
||||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
gpt_out = self.bottom_gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||||
if get_attns:
|
if get_attns:
|
||||||
return gpt_out.attentions
|
return gpt_out.attentions
|
||||||
|
|
||||||
|
@ -173,8 +231,9 @@ class UnifiedGptVoice(nn.Module):
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
mel_emb = self.bottom_gpt.get_input_embeddings()(mel_inputs)
|
||||||
mel_emb = mel_emb + self.mel_pos_paired_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
mel_emb = mel_emb + self.mel_pos_paired_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||||
|
|
||||||
if text_first:
|
if text_first:
|
||||||
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
||||||
else:
|
else:
|
||||||
|
@ -213,7 +272,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
mel_emb = self.bottom_gpt.get_input_embeddings()(mel_inputs)
|
||||||
mel_emb = mel_emb + self.mel_pos_solo_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
mel_emb = mel_emb + self.mel_pos_solo_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||||
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
|
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
|
||||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||||
|
@ -221,7 +280,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
|
|
||||||
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
||||||
if not hasattr(self, 'inference_model'):
|
if not hasattr(self, 'inference_model'):
|
||||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_paired_embedding, self.final_norm, self.mel_head)
|
self.inference_model = GPT2InferenceModel(self.bottom_gpt_config, self.bottom_gpt, self.mel_pos_paired_embedding, self.final_norm, self.mel_head)
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
||||||
|
@ -237,12 +296,12 @@ class UnifiedGptVoice(nn.Module):
|
||||||
fake_inputs[:,-1] = self.start_mel_token
|
fake_inputs[:,-1] = self.start_mel_token
|
||||||
|
|
||||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||||
max_length=self.gpt_config.n_positions, **hf_generate_kwargs)
|
max_length=self.bottom_gpt_config.n_positions, **hf_generate_kwargs)
|
||||||
return gen[:, fake_inputs.shape[1]:]
|
return gen[:, fake_inputs.shape[1]:]
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_unified_gpt_voice(opt_net, opt):
|
def register_unified_gpt_voice_bilevel(opt_net, opt):
|
||||||
return UnifiedGptVoice(**opt_get(opt_net, ['kwargs'], {}))
|
return UnifiedGptVoice(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user