fix bug and allow position encodings to be trained separately from the rest of the model
parent 09ab1aa9bc
commit 032983e2ed
@@ -241,7 +241,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -272,10 +272,10 @@ class UnifiedVoice(nn.Module):
         self.stop_mel_token = stop_mel_token
         self.layers = layers
         self.heads = heads
+        self.max_conditioning_inputs = max_conditioning_inputs
         self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
         self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
         self.model_dim = model_dim
-        self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
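The reordering above is the bug fix named in the commit title: self.max_mel_tokens is computed from self.max_conditioning_inputs, so in the old ordering the attribute was read before __init__ assigned it, and any construction with max_mel_tokens != -1 would raise AttributeError. A minimal standalone reduction of the same ordering mistake (hypothetical class, not from this repo):

# Illustration only: reads an attribute before __init__ assigns it,
# mirroring the pre-fix ordering in UnifiedVoice.__init__.
class Broken:
    def __init__(self, max_mel_tokens=100, max_conditioning_inputs=4):
        # AttributeError: 'Broken' object has no attribute 'max_conditioning_inputs'
        self.max_mel_tokens = max_mel_tokens + 2 + self.max_conditioning_inputs
        self.max_conditioning_inputs = max_conditioning_inputs  # assigned too late

# Broken()  # raises AttributeError immediately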
@@ -304,6 +304,15 @@ class UnifiedVoice(nn.Module):
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
 
+        if freeze_everything_but_position_embeddings:
+            for p in self.parameters():
+                p.requires_grad = False
+                p.DO_NOT_TRAIN = True
+            for m in [self.mel_pos_embedding, self.text_pos_embedding]:
+                for p in m.parameters():
+                    del p.DO_NOT_TRAIN
+                    p.requires_grad = True
+
     def get_grad_norm_parameter_groups(self):
         return {
             'conditioning_encoder': list(self.conditioning_encoder.parameters()),
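Two things happen in the new block: every parameter is frozen via requires_grad = False and tagged with an ad-hoc DO_NOT_TRAIN attribute (presumably consumed elsewhere by the training harness), then the tag is deleted and gradients re-enabled for the two position-embedding modules only. The tagging works because nn.Parameter is an ordinary Python object that accepts arbitrary attributes; a small self-contained check of that mechanic (illustration only, not from the commit):

import torch
import torch.nn as nn

# nn.Parameter is a Tensor subclass, so ad-hoc attributes like DO_NOT_TRAIN
# can be set, checked with getattr(), and deleted again.
p = nn.Parameter(torch.zeros(4))
p.DO_NOT_TRAIN = True
assert getattr(p, 'DO_NOT_TRAIN', False) is True
del p.DO_NOT_TRAIN
assert getattr(p, 'DO_NOT_TRAIN', False) is False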
@@ -566,7 +575,7 @@ def register_unified_voice2(opt_net, opt):
 
 
 if __name__ == '__main__':
-    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
+    gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
     l = gpt(torch.randn(2, 3, 80, 800),
             torch.randint(high=256, size=(2,120)),
             torch.tensor([32, 120]),
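With the new flag set, the smoke test above builds a model in which only the position embeddings should remain trainable. A hedged verification sketch, assuming the constructor accepts the same arguments as the smoke test (not part of the commit):

# Illustration only: confirm the freeze leaves just the position embeddings
# trainable, then build an optimizer over exactly those parameters.
import torch

gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True,
                   use_mel_codes_as_input=True, max_conditioning_inputs=4,
                   freeze_everything_but_position_embeddings=True)
trainable = [n for n, p in gpt.named_parameters() if p.requires_grad]
# Expect only mel_pos_embedding.* and text_pos_embedding.* names here.
print(trainable)
opt = torch.optim.AdamW((p for p in gpt.parameters() if p.requires_grad), lr=1e-4)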