This commit is contained in:
James Betker 2022-05-13 17:57:47 -06:00
parent eb64d18075
commit ee218ab9b7

View File

@ -238,7 +238,8 @@ class MelEncoder(nn.Module):
class UnifiedVoice(nn.Module): class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192, mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1): stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1,
freeze_for_aligned_codes=False,):
""" """
Args: Args:
layers: Number of layers in transformer stack. layers: Number of layers in transformer stack.
@ -278,13 +279,21 @@ class UnifiedVoice(nn.Module):
self.final_norm = nn.LayerNorm(model_dim) self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes) self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.aligned_head = nn.Linear(model_dim, self.number_aligned_text_codes) self.aligned_head = nn.Linear(model_dim, number_aligned_text_codes)
# Initialize the embeddings per the GPT-2 scheme # Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding] embeddings = [self.text_embedding, self.mel_embedding]
for module in embeddings: for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02) module.weight.data.normal_(mean=0.0, std=.02)
if freeze_for_aligned_codes:
for p in self.parameters():
p.DO_NOT_TRAIN = True
p.requires_grad = False
for p in self.aligned_head.parameters():
del p.DO_NOT_TRAIN
p.requires_grad = True
def get_grad_norm_parameter_groups(self): def get_grad_norm_parameter_groups(self):
return { return {
'conditioning_encoder': list(self.conditioning_encoder.parameters()), 'conditioning_encoder': list(self.conditioning_encoder.parameters()),
@ -363,18 +372,18 @@ class UnifiedVoice(nn.Module):
if types is not None: if types is not None:
text_inputs = text_inputs * (1+types).unsqueeze(-1) text_inputs = text_inputs * (1+types).unsqueeze(-1)
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
conds = self.get_conditioning_latent(speech_conditioning_input) conds = self.get_conditioning_latent(speech_conditioning_input)
ac_expansion_factor = mel_codes.shape[-1] // aligned_codes.shape[-1] ac_expansion_factor = mel_codes.shape[-1] / aligned_codes.shape[-1]
aligned_codes = aligned_codes.repeat(1, ac_expansion_factor) aligned_codes = aligned_codes.repeat(1, ac_expansion_factor)
_, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0) _, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0)
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
mel_inp = mel_codes mel_inp = mel_codes
mel_emb = self.mel_embedding(mel_inp) mel_emb = self.mel_embedding(mel_inp)
@ -431,15 +440,16 @@ class UnifiedVoice(nn.Module):
@register_model @register_model
def register_unified_voice2(opt_net, opt): def register_unified_voice3(opt_net, opt):
return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {})) return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__': if __name__ == '__main__':
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2) gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
mel = torch.randint(high=8192, size=(2,250))
ac = torch.randint(high=256, size=(2,250*1024//443))
l = gpt(torch.randn(2, 3, 80, 800), l = gpt(torch.randn(2, 3, 80, 800),
torch.randint(high=256, size=(2,120)), torch.randint(high=256, size=(2,120)),
torch.tensor([32, 120]), torch.tensor([32, 120]),
torch.randint(high=8192, size=(2,250)), mel, torch.tensor([250*256,195*256]), ac,
torch.tensor([250*256,195*256]),
types=torch.tensor([0, 1])) types=torch.tensor([0, 1]))