forked from mrq/DL-Art-School
freeze quantizer until step
commit 4c6ef42b38
parent 64b6ae2f4a
@@ -199,9 +199,11 @@ class TransformerDiffusion(nn.Module):


 class TransformerDiffusionWithQuantizer(nn.Module):
-    def __init__(self, **kwargs):
+    def __init__(self, freeze_quantizer_until=20000, **kwargs):
         super().__init__()
+        self.internal_step = 0
+        self.freeze_quantizer_until = freeze_quantizer_until
         self.diff = TransformerDiffusion(**kwargs)
         from models.audio.mel2vec import ContrastiveTrainingWrapper
         self.m2v = MusicQuantizer(inp_channels=256, inner_dim=2048, codevector_dim=1024)
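Not part of the diff: a hedged sketch of how this freeze window is meant to be driven. It assumes the DL-Art-School trainer calls update_for_step() on injected models once per iteration (an assumption about the surrounding training loop, not something shown in this commit); the toy loop below only imitates that hook.

# Illustrative only: imitate the trainer hook that advances internal_step.
model = TransformerDiffusionWithQuantizer(freeze_quantizer_until=20000, model_channels=2048,
                                           block_channels=1024, prenet_channels=1024,
                                           input_vec_dim=2048, num_layers=16)
for step in range(100_000):
    model.update_for_step(step)
    # ... forward/backward/optimizer.step() would go here ...
    # For step <= 20000 the quantizer receives no gradient and its Gumbel temperature
    # stays at its maximum; after that it trains normally and the temperature decays.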
@@ -210,13 +212,24 @@ class TransformerDiffusionWithQuantizer(nn.Module):

     def update_for_step(self, step, *args):
         self.internal_step = step
+        qstep = max(0, self.internal_step - self.freeze_quantizer_until)
         self.m2v.quantizer.temperature = max(
-                    self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**step,
+                    self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**qstep,
                     self.m2v.min_gumbel_temperature,
                 )

     def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
-        proj = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+        quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
+        with torch.set_grad_enabled(quant_grad_enabled):
+            proj = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+
+        # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
+        if not quant_grad_enabled:
+            unused = 0
+            for p in self.m2v.parameters():
+                unused = unused + p.mean() * 0
+            proj = proj + unused
+
         return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
                          conditioning_free=conditioning_free)
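The qstep offset above effectively delays the Gumbel temperature decay until the quantizer is unfrozen. A standalone sketch of that schedule; the max/min/decay values are illustrative placeholders, not taken from the repository's configs:

def gumbel_temperature(step, freeze_until=20000,
                       max_temp=2.0, min_temp=0.5, decay=0.999995):
    # While the quantizer is frozen, the schedule is held at its starting value;
    # exponential decay toward the minimum only begins once the freeze window ends.
    qstep = max(0, step - freeze_until)
    return max(max_temp * decay ** qstep, min_temp)

assert gumbel_temperature(0) == gumbel_temperature(20000)   # flat during the freeze
assert gumbel_temperature(500_000) >= 0.5                    # clamped at the minimum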
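The zero-weighted parameter touch in forward() is a common workaround for DistributedDataParallel, which complains when registered parameters receive no gradient during a backward pass. A generic sketch of the same idea, with frozen_module and active_out as placeholder names:

import torch

def touch_unused_params(active_out, frozen_module):
    # Add each parameter's mean multiplied by zero to the output: the values are
    # unchanged, but every parameter now appears in the autograd graph, so the DDP
    # reducer sees a (zero) gradient for it instead of flagging it as unused on
    # ranks where the module was skipped or run without gradients.
    unused = 0
    for p in frozen_module.parameters():
        unused = unused + p.mean() * 0
    return active_out + unused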
@@ -256,12 +269,12 @@ if __name__ == '__main__':
     ts = torch.LongTensor([600, 600])
     model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=2048, num_layers=16)

-    quant_weights = torch.load('X:\\dlas\\experiments\\train_music_quant\\models\\1000_generator.pth')
+    #quant_weights = torch.load('X:\\dlas\\experiments\\train_music_quant\\models\\1000_generator.pth')
     #diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth')
-    model.m2v.load_state_dict(quant_weights, strict=False)
+    #model.m2v.load_state_dict(quant_weights, strict=False)
     #model.diff.load_state_dict(diff_weights)

-    torch.save(model.state_dict(), 'sample.pth')
+    #torch.save(model.state_dict(), 'sample.pth')
     print_network(model)
     o = model(clip, ts, clip, cond)
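A possible extension of the harness above (not in the commit) to exercise both sides of the freeze boundary, assuming the default freeze_quantizer_until=20000:

model.update_for_step(0)
o = model(clip, ts, clip, cond)       # internal_step <= 20000: quantizer runs under torch.set_grad_enabled(False)
model.update_for_step(30000)
o = model(clip, ts, clip, cond)       # past the freeze window: quantizer joins the backward pass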