forked from mrq/DL-Art-School
Add mel_encoder and solo embeddings to unified_voice
This commit is contained in:
parent
2165124f19
commit
963c6072bb
|
@ -250,7 +250,7 @@ class GptAsrHf2(nn.Module):
|
||||||
# This model uses its own positional embeddings, which helps discriminate between text and audio MELs.
|
# This model uses its own positional embeddings, which helps discriminate between text and audio MELs.
|
||||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||||
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
|
self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim)
|
||||||
self.text_solo_embedding = nn.Parameter(torch.randn(1,1,512) * self.gpt.config.initializer_range, requires_grad=True)
|
self.text_solo_embedding = nn.Parameter(torch.randn(1,1,model_dim) * self.gpt.config.initializer_range, requires_grad=True)
|
||||||
|
|
||||||
# Head layers
|
# Head layers
|
||||||
self.final_norm = nn.LayerNorm(model_dim)
|
self.final_norm = nn.LayerNorm(model_dim)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from transformers import GPT2Model, GPT2Config
|
||||||
|
|
||||||
from models.arch_util import AttentionBlock
|
from models.arch_util import AttentionBlock
|
||||||
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
|
from models.gpt_voice.gpt_asr_hf import GPT2InferenceModel
|
||||||
|
from models.gpt_voice.gpt_asr_hf2 import ResBlock
|
||||||
from models.tacotron2.text import symbols
|
from models.tacotron2.text import symbols
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
@ -34,6 +35,30 @@ class ConditioningEncoder(nn.Module):
|
||||||
return h[:, :, 0]
|
return h[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
|
class MelEncoder(nn.Module):
|
||||||
|
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
||||||
|
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||||
|
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.GroupNorm(channels//16, channels//2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||||
|
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.GroupNorm(channels//8, channels),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||||
|
)
|
||||||
|
self.reduction = 4
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for e in self.encoder:
|
||||||
|
x = e(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def null_position_embeddings(range, dim):
|
def null_position_embeddings(range, dim):
|
||||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||||
|
|
||||||
|
@ -50,7 +75,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
|
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=120, max_mel_tokens=250, max_total_tokens=370, max_conditioning_inputs=3,
|
||||||
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
|
checkpointing=True, mel_length_compression=1024, max_conditioning_length=60, number_text_tokens=256,
|
||||||
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
||||||
stop_mel_token=8193, use_dedicated_position_embeddings_for_paired=True, shuffle_conditioning=True):
|
stop_mel_token=8193, shuffle_conditioning=True, train_solo_embeddings=False, use_mel_codes_as_input=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.number_text_tokens = number_text_tokens
|
self.number_text_tokens = number_text_tokens
|
||||||
|
@ -69,14 +94,8 @@ class UnifiedGptVoice(nn.Module):
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||||
self.text_pos_solo_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||||
self.mel_pos_solo_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
||||||
if use_dedicated_position_embeddings_for_paired:
|
|
||||||
self.mel_pos_paired_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
|
||||||
self.text_pos_paired_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
|
||||||
else:
|
|
||||||
self.mel_pos_paired_embedding = self.mel_pos_solo_embedding
|
|
||||||
self.text_pos_paired_embedding = self.text_pos_solo_embedding
|
|
||||||
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
|
seq_length = 2+self.max_total_tokens+self.max_conditioning_inputs
|
||||||
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
||||||
n_positions=seq_length,
|
n_positions=seq_length,
|
||||||
|
@ -87,18 +106,26 @@ class UnifiedGptVoice(nn.Module):
|
||||||
gradient_checkpointing=checkpointing,
|
gradient_checkpointing=checkpointing,
|
||||||
use_cache=not checkpointing)
|
use_cache=not checkpointing)
|
||||||
self.gpt = GPT2Model(self.gpt_config)
|
self.gpt = GPT2Model(self.gpt_config)
|
||||||
|
if train_solo_embeddings:
|
||||||
|
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
|
||||||
|
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * self.gpt.config.initializer_range, requires_grad=True)
|
||||||
|
else:
|
||||||
|
self.mel_solo_embedding = 0
|
||||||
|
self.text_solo_embedding = 0
|
||||||
# Override the built in positional embeddings
|
# Override the built in positional embeddings
|
||||||
del self.gpt.wpe
|
del self.gpt.wpe
|
||||||
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||||
|
|
||||||
|
if not use_mel_codes_as_input:
|
||||||
|
self.gpt.wte = MelEncoder(model_dim, resblocks_per_reduction=1)
|
||||||
|
|
||||||
self.final_norm = nn.LayerNorm(model_dim)
|
self.final_norm = nn.LayerNorm(model_dim)
|
||||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||||
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
||||||
self.max_conditioning_length = max_conditioning_length
|
self.max_conditioning_length = max_conditioning_length
|
||||||
|
|
||||||
# Initialize the embeddings per the GPT-2 scheme
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
for module in [self.text_embedding, self.text_pos_solo_embedding, self.text_pos_paired_embedding,
|
for module in [self.text_embedding, self.text_pos_embedding, self.mel_pos_embedding]:
|
||||||
self.mel_pos_solo_embedding, self.mel_pos_paired_embedding]:
|
|
||||||
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
@ -177,10 +204,10 @@ class UnifiedGptVoice(nn.Module):
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
||||||
mel_emb = mel_emb + self.mel_pos_paired_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||||
if text_first:
|
if text_first:
|
||||||
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
||||||
else:
|
else:
|
||||||
|
@ -204,7 +231,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
|
||||||
|
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_solo_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_solo_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + self.text_solo_embedding
|
||||||
text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
|
text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
|
||||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
return loss_text.mean()
|
return loss_text.mean()
|
||||||
|
@ -222,7 +249,7 @@ class UnifiedGptVoice(nn.Module):
|
||||||
|
|
||||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
|
||||||
mel_emb = mel_emb + self.mel_pos_solo_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
mel_emb = mel_emb + self.mel_pos_solo_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.mel_solo_embedding
|
||||||
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
|
mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head)
|
||||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||||
return loss_mel.mean()
|
return loss_mel.mean()
|
||||||
|
@ -256,7 +283,7 @@ def register_unified_gpt_voice(opt_net, opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gpt = UnifiedGptVoice(model_dim=256, heads=4, use_dedicated_position_embeddings_for_paired=False)
|
gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True)
|
||||||
l = gpt(torch.randn(2, 80, 800),
|
l = gpt(torch.randn(2, 80, 800),
|
||||||
torch.randint(high=len(symbols), size=(2,80)),
|
torch.randint(high=len(symbols), size=(2,80)),
|
||||||
torch.randint(high=8192, size=(2,250)),
|
torch.randint(high=8192, size=(2,250)),
|
||||||
|
|
|
@ -286,7 +286,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mel_encoder_pred_codes.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user