forked from mrq/DL-Art-School
allow freezing the upper quantizer
parent 43f225c35c
commit 16936881e5
@@ -58,9 +58,11 @@ class UpperConditioningEncoder(nn.Module):
 
 
 class GptMusicLower(nn.Module):
-    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
+    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64,
+                 num_upper_groups=4, fp16=True, freeze_upper_until=0):
         super().__init__()
         self.internal_step = 0
+        self.freeze_upper_until = freeze_upper_until
         self.num_groups = num_target_groups
         self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                  n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
@@ -103,10 +105,16 @@ class GptMusicLower(nn.Module):
 
 
     def forward(self, mel, conditioning, return_latent=False):
+        unused_params = []
         with torch.no_grad():
             self.target_quantizer.eval()
             codes = self.target_quantizer.get_codes(mel)
-        upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
+        if self.freeze_upper_until > self.internal_step:
+            with torch.no_grad():
+                upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
+            unused_params.extend(list(self.upper_quantizer.parameters()))
+        else:
+            upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
         upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1)  # Allow the upper vector to fully attend to itself (the whole thing is a prior.)
         upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
         upper_vector = upper_vector.permute(0,2,1)
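The hunk above is the core of the change: until the model's step counter reaches freeze_upper_until, the upper quantizer is run under torch.no_grad() and its parameters are collected as unused. A minimal, self-contained sketch of the same pattern follows; StepGatedModel, upper and head are illustrative stand-ins, not names from the repository:

import torch
import torch.nn as nn

class StepGatedModel(nn.Module):
    # Toy stand-in for GptMusicLower: the 'upper' branch runs without gradients
    # until internal_step reaches freeze_upper_until.
    def __init__(self, freeze_upper_until=0):
        super().__init__()
        self.internal_step = 0
        self.freeze_upper_until = freeze_upper_until
        self.upper = nn.Linear(16, 16)   # stands in for self.upper_quantizer
        self.head = nn.Linear(16, 8)     # stands in for the rest of the network

    def forward(self, x):
        unused_params = []
        if self.freeze_upper_until > self.internal_step:
            with torch.no_grad():        # frozen phase: no gradients reach 'upper'
                h = self.upper(x)
            unused_params.extend(list(self.upper.parameters()))
        else:
            h = self.upper(x)            # past the threshold: trained normally
        return self.head(h).mean(), unused_params

Note that with the default freeze_upper_until=0 the condition is never true, so existing configurations keep their previous behaviour.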
@@ -135,6 +143,11 @@ class GptMusicLower(nn.Module):
             loss = F.cross_entropy(logits, targets[:,:,i])
             losses = losses + loss
 
+        unused_adder = 0
+        for p in unused_params:
+            unused_adder = unused_adder + p.mean() * 0
+        losses = losses + unused_adder
+
         return losses / self.num_groups, upper_diversity
 
     def get_grad_norm_parameter_groups(self):
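This hunk adds a zero-weighted term over the parameters skipped in the frozen phase. The commit does not state the motivation, but this is the usual workaround when a distributed wrapper such as torch.nn.parallel.DistributedDataParallel expects every registered parameter to take part in the backward pass: multiplying each parameter's mean by 0 keeps it in the autograd graph and gives it an exact zero gradient instead of no gradient at all. A small standalone illustration (layer and the printout are only for demonstration):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)                      # pretend these parameters were frozen this step
unused_params = list(layer.parameters())

loss = torch.zeros((), requires_grad=True)   # stands in for the real training loss
unused_adder = 0
for p in unused_params:
    unused_adder = unused_adder + p.mean() * 0
loss = loss + unused_adder
loss.backward()

# Every parameter now has a gradient tensor (all zeros) rather than grad=None.
print([p.grad.abs().sum().item() for p in layer.parameters()])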
@@ -256,7 +269,7 @@ def test_lower():
                              dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
     base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
 
-    model = GptMusicLower(512, 8, fp16=False)
+    model = GptMusicLower(512, 8, fp16=False, freeze_upper_until=100)
     model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
     torch.save(model.state_dict(), "sample.pth")
     mel = torch.randn(2,256,400)
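Finally, test_lower() is updated to exercise the new argument. As a usage sketch with the toy StepGatedModel from above (the real GptMusicLower additionally needs the pretrained quantizer checkpoint loaded in this test):

model = StepGatedModel(freeze_upper_until=100)
x = torch.randn(2, 16)

loss, unused = model(x)          # internal_step == 0: upper branch frozen
assert len(unused) > 0

model.internal_step = 100        # however the training loop advances the counter
loss, unused = model(x)          # threshold reached: upper branch trains normally
assert len(unused) == 0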