uv3
This commit is contained in:
parent
eb64d18075
commit
ee218ab9b7
|
@ -238,7 +238,8 @@ class MelEncoder(nn.Module):
|
||||||
class UnifiedVoice(nn.Module):
|
class UnifiedVoice(nn.Module):
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
||||||
mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
|
mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
|
||||||
stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1):
|
stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1,
|
||||||
|
freeze_for_aligned_codes=False,):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
layers: Number of layers in transformer stack.
|
layers: Number of layers in transformer stack.
|
||||||
|
@ -278,13 +279,21 @@ class UnifiedVoice(nn.Module):
|
||||||
self.final_norm = nn.LayerNorm(model_dim)
|
self.final_norm = nn.LayerNorm(model_dim)
|
||||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
|
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
|
||||||
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
||||||
self.aligned_head = nn.Linear(model_dim, self.number_aligned_text_codes)
|
self.aligned_head = nn.Linear(model_dim, number_aligned_text_codes)
|
||||||
|
|
||||||
# Initialize the embeddings per the GPT-2 scheme
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
embeddings = [self.text_embedding, self.mel_embedding]
|
embeddings = [self.text_embedding, self.mel_embedding]
|
||||||
for module in embeddings:
|
for module in embeddings:
|
||||||
module.weight.data.normal_(mean=0.0, std=.02)
|
module.weight.data.normal_(mean=0.0, std=.02)
|
||||||
|
|
||||||
|
if freeze_for_aligned_codes:
|
||||||
|
for p in self.parameters():
|
||||||
|
p.DO_NOT_TRAIN = True
|
||||||
|
p.requires_grad = False
|
||||||
|
for p in self.aligned_head.parameters():
|
||||||
|
del p.DO_NOT_TRAIN
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
return {
|
return {
|
||||||
'conditioning_encoder': list(self.conditioning_encoder.parameters()),
|
'conditioning_encoder': list(self.conditioning_encoder.parameters()),
|
||||||
|
@ -363,18 +372,18 @@ class UnifiedVoice(nn.Module):
|
||||||
if types is not None:
|
if types is not None:
|
||||||
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
||||||
|
|
||||||
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
|
||||||
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
|
||||||
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
|
|
||||||
|
|
||||||
conds = self.get_conditioning_latent(speech_conditioning_input)
|
conds = self.get_conditioning_latent(speech_conditioning_input)
|
||||||
|
|
||||||
ac_expansion_factor = mel_codes.shape[-1] // aligned_codes.shape[-1]
|
# Floor division is required here: the factor is passed to Tensor.repeat(), which only accepts integer repeat counts (a true-division float raises TypeError).
ac_expansion_factor = mel_codes.shape[-1] // aligned_codes.shape[-1]
|
||||||
aligned_codes = aligned_codes.repeat(1, ac_expansion_factor)
|
aligned_codes = aligned_codes.repeat(1, ac_expansion_factor)
|
||||||
_, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0)
|
_, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0)
|
||||||
|
|
||||||
|
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||||
|
|
||||||
|
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||||
mel_inp = mel_codes
|
mel_inp = mel_codes
|
||||||
mel_emb = self.mel_embedding(mel_inp)
|
mel_emb = self.mel_embedding(mel_inp)
|
||||||
|
@ -431,15 +440,16 @@ class UnifiedVoice(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_unified_voice2(opt_net, opt):
|
def register_unified_voice3(opt_net, opt):
|
||||||
return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {}))
|
return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
|
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
|
||||||
|
mel = torch.randint(high=8192, size=(2,250))
|
||||||
|
ac = torch.randint(high=256, size=(2,250*1024//443))
|
||||||
l = gpt(torch.randn(2, 3, 80, 800),
|
l = gpt(torch.randn(2, 3, 80, 800),
|
||||||
torch.randint(high=256, size=(2,120)),
|
torch.randint(high=256, size=(2,120)),
|
||||||
torch.tensor([32, 120]),
|
torch.tensor([32, 120]),
|
||||||
torch.randint(high=8192, size=(2,250)),
|
mel, torch.tensor([250*256,195*256]), ac,
|
||||||
torch.tensor([250*256,195*256]),
|
|
||||||
types=torch.tensor([0, 1]))
|
types=torch.tensor([0, 1]))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user