uv3
This commit is contained in:
parent
eb64d18075
commit
ee218ab9b7
|
@ -238,7 +238,8 @@ class MelEncoder(nn.Module):
|
|||
class UnifiedVoice(nn.Module):
|
||||
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
||||
mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192,
|
||||
stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1):
|
||||
stop_mel_token=8193, start_text_token=None, number_aligned_text_codes=256, checkpointing=True, types=1,
|
||||
freeze_for_aligned_codes=False,):
|
||||
"""
|
||||
Args:
|
||||
layers: Number of layers in transformer stack.
|
||||
|
@ -278,13 +279,21 @@ class UnifiedVoice(nn.Module):
|
|||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
|
||||
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
||||
self.aligned_head = nn.Linear(model_dim, self.number_aligned_text_codes)
|
||||
self.aligned_head = nn.Linear(model_dim, number_aligned_text_codes)
|
||||
|
||||
# Initialize the embeddings per the GPT-2 scheme
|
||||
embeddings = [self.text_embedding, self.mel_embedding]
|
||||
for module in embeddings:
|
||||
module.weight.data.normal_(mean=0.0, std=.02)
|
||||
|
||||
if freeze_for_aligned_codes:
|
||||
for p in self.parameters():
|
||||
p.DO_NOT_TRAIN = True
|
||||
p.requires_grad = False
|
||||
for p in self.aligned_head.parameters():
|
||||
del p.DO_NOT_TRAIN
|
||||
p.requires_grad = True
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
'conditioning_encoder': list(self.conditioning_encoder.parameters()),
|
||||
|
@ -363,18 +372,18 @@ class UnifiedVoice(nn.Module):
|
|||
if types is not None:
|
||||
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
||||
|
||||
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
||||
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
|
||||
|
||||
conds = self.get_conditioning_latent(speech_conditioning_input)
|
||||
|
||||
ac_expansion_factor = mel_codes.shape[-1] // aligned_codes.shape[-1]
|
||||
ac_expansion_factor = mel_codes.shape[-1] / aligned_codes.shape[-1]
|
||||
aligned_codes = aligned_codes.repeat(1, ac_expansion_factor)
|
||||
_, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0)
|
||||
|
||||
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
mel_inp = mel_codes
|
||||
mel_emb = self.mel_embedding(mel_inp)
|
||||
|
@ -431,15 +440,16 @@ class UnifiedVoice(nn.Module):
|
|||
|
||||
|
||||
@register_model
|
||||
def register_unified_voice2(opt_net, opt):
|
||||
def register_unified_voice3(opt_net, opt):
|
||||
return UnifiedVoice(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
|
||||
mel = torch.randint(high=8192, size=(2,250))
|
||||
ac = torch.randint(high=256, size=(2,250*1024//443))
|
||||
l = gpt(torch.randn(2, 3, 80, 800),
|
||||
torch.randint(high=256, size=(2,120)),
|
||||
torch.tensor([32, 120]),
|
||||
torch.randint(high=8192, size=(2,250)),
|
||||
torch.tensor([250*256,195*256]),
|
||||
mel, torch.tensor([250*256,195*256]), ac,
|
||||
types=torch.tensor([0, 1]))
|
||||
|
|
Loading…
Reference in New Issue
Block a user