forked from mrq/DL-Art-School
allow freezing the upper quantizer
This commit is contained in:
parent
43f225c35c
commit
16936881e5
|
@ -58,9 +58,11 @@ class UpperConditioningEncoder(nn.Module):
|
|||
|
||||
|
||||
class GptMusicLower(nn.Module):
|
||||
def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
|
||||
def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64,
|
||||
num_upper_groups=4, fp16=True, freeze_upper_until=0):
|
||||
super().__init__()
|
||||
self.internal_step = 0
|
||||
self.freeze_upper_until = freeze_upper_until
|
||||
self.num_groups = num_target_groups
|
||||
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
|
||||
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
|
||||
|
@ -103,10 +105,16 @@ class GptMusicLower(nn.Module):
|
|||
|
||||
|
||||
def forward(self, mel, conditioning, return_latent=False):
|
||||
unused_params = []
|
||||
with torch.no_grad():
|
||||
self.target_quantizer.eval()
|
||||
codes = self.target_quantizer.get_codes(mel)
|
||||
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
|
||||
if self.freeze_upper_until > self.internal_step:
|
||||
with torch.no_grad():
|
||||
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
|
||||
unused_params.extend(list(self.upper_quantizer.parameters()))
|
||||
else:
|
||||
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
|
||||
upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1) # Allow the upper vector to fully attend to itself (the whole thing is a prior.)
|
||||
upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
|
||||
upper_vector = upper_vector.permute(0,2,1)
|
||||
|
@ -135,6 +143,11 @@ class GptMusicLower(nn.Module):
|
|||
loss = F.cross_entropy(logits, targets[:,:,i])
|
||||
losses = losses + loss
|
||||
|
||||
unused_adder = 0
|
||||
for p in unused_params:
|
||||
unused_adder = unused_adder + p.mean() * 0
|
||||
losses = losses + unused_adder
|
||||
|
||||
return losses / self.num_groups, upper_diversity
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
|
@ -256,7 +269,7 @@ def test_lower():
|
|||
dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
|
||||
base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
|
||||
|
||||
model = GptMusicLower(512, 8, fp16=False)
|
||||
model = GptMusicLower(512, 8, fp16=False, freeze_upper_until=100)
|
||||
model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
|
||||
torch.save(model.state_dict(), "sample.pth")
|
||||
mel = torch.randn(2,256,400)
|
||||
|
|
Loading…
Reference in New Issue
Block a user