allow freezing the upper quantizer

This commit is contained in:
James Betker 2022-06-08 18:30:22 -06:00
parent 43f225c35c
commit 16936881e5

View File

@ -58,9 +58,11 @@ class UpperConditioningEncoder(nn.Module):
class GptMusicLower(nn.Module):
def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64,
num_upper_groups=4, fp16=True, freeze_upper_until=0):
super().__init__()
self.internal_step = 0
self.freeze_upper_until = freeze_upper_until
self.num_groups = num_target_groups
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
@ -103,10 +105,16 @@ class GptMusicLower(nn.Module):
def forward(self, mel, conditioning, return_latent=False):
unused_params = []
with torch.no_grad():
self.target_quantizer.eval()
codes = self.target_quantizer.get_codes(mel)
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
if self.freeze_upper_until > self.internal_step:
with torch.no_grad():
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
unused_params.extend(list(self.upper_quantizer.parameters()))
else:
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1) # Allow the upper vector to fully attend to itself (the whole thing is a prior.)
upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
upper_vector = upper_vector.permute(0,2,1)
@ -135,6 +143,11 @@ class GptMusicLower(nn.Module):
loss = F.cross_entropy(logits, targets[:,:,i])
losses = losses + loss
unused_adder = 0
for p in unused_params:
unused_adder = unused_adder + p.mean() * 0
losses = losses + unused_adder
return losses / self.num_groups, upper_diversity
def get_grad_norm_parameter_groups(self):
@ -256,7 +269,7 @@ def test_lower():
dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
model = GptMusicLower(512, 8, fp16=False)
model = GptMusicLower(512, 8, fp16=False, freeze_upper_until=100)
model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
torch.save(model.state_dict(), "sample.pth")
mel = torch.randn(2,256,400)