allow freezing the upper quantizer

James Betker 2022-06-08 18:30:22 -06:00
parent 43f225c35c
commit 16936881e5


@@ -58,9 +58,11 @@ class UpperConditioningEncoder(nn.Module):
 class GptMusicLower(nn.Module):
-    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
+    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64,
+                 num_upper_groups=4, fp16=True, freeze_upper_until=0):
         super().__init__()
         self.internal_step = 0
+        self.freeze_upper_until = freeze_upper_until
         self.num_groups = num_target_groups
         self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                  n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
@@ -103,10 +105,16 @@ class GptMusicLower(nn.Module):
     def forward(self, mel, conditioning, return_latent=False):
+        unused_params = []
         with torch.no_grad():
             self.target_quantizer.eval()
             codes = self.target_quantizer.get_codes(mel)
-        upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
+        if self.freeze_upper_until > self.internal_step:
+            with torch.no_grad():
+                upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
+            unused_params.extend(list(self.upper_quantizer.parameters()))
+        else:
+            upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
         upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1)  # Allow the upper vector to fully attend to itself (the whole thing is a prior.)
         upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
         upper_vector = upper_vector.permute(0,2,1)
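Note: this hunk implements a step-gated freeze, a common "keep a sub-module fixed for the first N steps" pattern. While the current step is below the threshold, the sub-module runs under torch.no_grad(), so its outputs still feed the rest of the model but its weights receive no gradient. A minimal standalone sketch of the idea (function and variable names here are illustrative, not part of this commit):

    import torch

    def run_upper(quantizer, mel, step, freeze_until):
        # Below the threshold: run without gradient tracking so the
        # quantizer's weights stay fixed while its outputs still feed
        # the rest of the model.
        if freeze_until > step:
            with torch.no_grad():
                vec, diversity = quantizer(mel, return_decoder_latent=True)
            skipped = list(quantizer.parameters())  # params that get no grad this step
        else:
            vec, diversity = quantizer(mel, return_decoder_latent=True)
            skipped = []
        return vec, diversity, skipped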
@@ -135,6 +143,11 @@ class GptMusicLower(nn.Module):
             loss = F.cross_entropy(logits, targets[:,:,i])
             losses = losses + loss
 
+        unused_adder = 0
+        for p in unused_params:
+            unused_adder = unused_adder + p.mean() * 0
+        losses = losses + unused_adder
+
         return losses / self.num_groups, upper_diversity
 
     def get_grad_norm_parameter_groups(self):
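Note: folding the frozen parameters back into the loss with zero weight is a common workaround for DistributedDataParallel, which by default errors out when some parameters receive no gradient in a step; the p.mean() * 0 terms create gradient edges without changing the loss value. A minimal sketch of the trick in isolation (names are illustrative):

    import torch

    def attach_unused(loss: torch.Tensor, unused_params) -> torch.Tensor:
        # Gradients w.r.t. these parameters will be exactly zero, but DDP
        # still sees them participate in the backward pass.
        adder = 0
        for p in unused_params:
            adder = adder + p.mean() * 0
        return loss + adder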
@@ -256,7 +269,7 @@ def test_lower():
                                     dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
     base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
-    model = GptMusicLower(512, 8, fp16=False)
+    model = GptMusicLower(512, 8, fp16=False, freeze_upper_until=100)
     model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
     torch.save(model.state_dict(), "sample.pth")
     mel = torch.randn(2,256,400)
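Note: assuming internal_step is advanced once per training iteration elsewhere in the training loop (that update is not part of this diff), the upper quantizer stays frozen for the first freeze_upper_until steps and is trained normally afterwards. A hedged usage sketch (the conditioning shape is a guess for illustration only):

    import torch

    model = GptMusicLower(512, 8, fp16=False, freeze_upper_until=100)
    mel = torch.randn(2, 256, 400)
    conditioning = torch.randn(2, 256, 400)  # assumed shape, not taken from the diff

    model.internal_step = 0     # below the threshold: upper_quantizer runs under no_grad
    losses, diversity = model(mel, conditioning)

    model.internal_step = 200   # past the threshold: upper_quantizer receives gradients
    losses, diversity = model(mel, conditioning)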