fix bug and allow position encodings to be trained separately from the rest of the model

James Betker 2022-04-08 16:26:01 -06:00
parent 09ab1aa9bc
commit 032983e2ed


@@ -241,7 +241,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -272,10 +272,10 @@ class UnifiedVoice(nn.Module):
         self.stop_mel_token = stop_mel_token
         self.layers = layers
         self.heads = heads
+        self.max_conditioning_inputs = max_conditioning_inputs
         self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
         self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
         self.model_dim = model_dim
-        self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
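
Note on the hunk above: this is the "fix bug" half of the commit. self.max_mel_tokens is computed from self.max_conditioning_inputs, so that attribute must be assigned before the computation; previously it was assigned further down, after self.model_dim, and was read before it existed (an AttributeError whenever max_mel_tokens is not -1). A minimal standalone sketch of the failure mode (hypothetical class names, not code from this repo):

# Reading an attribute in __init__ before it has been assigned raises AttributeError.
class Broken:
    def __init__(self, n):
        self.total = n + self.offset  # AttributeError: 'Broken' object has no attribute 'offset'
        self.offset = 2

class Fixed:
    def __init__(self, n):
        self.offset = 2               # assign first...
        self.total = n + self.offset  # ...then use it

Fixed(10)     # ok
# Broken(10)  # would raise AttributeError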
@@ -304,6 +304,15 @@ class UnifiedVoice(nn.Module):
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
+        if freeze_everything_but_position_embeddings:
+            for p in self.parameters():
+                p.requires_grad = False
+                p.DO_NOT_TRAIN = True
+            for m in [self.mel_pos_embedding, self.text_pos_embedding]:
+                for p in m.parameters():
+                    del p.DO_NOT_TRAIN
+                    p.requires_grad = True
+
     def get_grad_norm_parameter_groups(self):
         return {
             'conditioning_encoder': list(self.conditioning_encoder.parameters()),
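
The new block above freezes every parameter (requires_grad=False plus a DO_NOT_TRAIN attribute, presumably consumed by the surrounding training framework) and then re-enables only the two learned position-embedding modules. A training script would then hand just the still-trainable parameters to its optimizer; a minimal sketch, assuming a plain PyTorch optimizer rather than this repo's trainer:

import torch

# Build the model with everything frozen except the position embeddings, then
# collect only the parameters that are still trainable and give those to the
# optimizer. The requires_grad / DO_NOT_TRAIN filtering shown here is an
# illustrative assumption, not this repo's training loop.
gpt = UnifiedVoice(model_dim=256, heads=4,
                   freeze_everything_but_position_embeddings=True)
trainable = [p for p in gpt.parameters()
             if p.requires_grad and not getattr(p, 'DO_NOT_TRAIN', False)]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)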
@@ -566,7 +575,7 @@ def register_unified_voice2(opt_net, opt):
 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
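
To confirm the new flag behaves as intended, the smoke test above can be followed by a quick check that only the position-embedding parameters remain trainable (a sketch using the same constructor arguments; the exact parameter names are whatever named_parameters() reports for the two embedding modules):

gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True,
                   use_mel_codes_as_input=True, max_conditioning_inputs=4,
                   freeze_everything_but_position_embeddings=True)
still_trainable = [n for n, p in gpt.named_parameters() if p.requires_grad]
print(still_trainable)  # expected: only mel_pos_embedding.* and text_pos_embedding.* entries
assert all('pos_embedding' in n for n in still_trainable)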