try to make tfd8 be able to be trained e2e in quantizer mode

This commit is contained in:
James Betker 2022-06-10 10:40:56 -06:00
parent e78c4b422c
commit 97b32dd39d
2 changed files with 37 additions and 27 deletions

View File

@ -224,14 +224,18 @@ class MusicQuantizer2(nn.Module):
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
self.log_codes(codes)
h = self.decoder(codevectors.permute(0,2,1))
if return_decoder_latent:
return h, diversity
if not hasattr(self, 'up') and return_decoder_latent:
return None, diversity, h
reconstructed = self.up(h.float())
reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
mse = F.mse_loss(reconstructed, orig_mel)
return mse, diversity
if return_decoder_latent:
return mse, diversity, h
else:
return mse, diversity
def log_codes(self, codes):
if self.internal_step % 5 == 0:

View File

@ -196,16 +196,19 @@ class TransformerDiffusion(nn.Module):
class TransformerDiffusionWithQuantizer(nn.Module):
def __init__(self, freeze_quantizer_until=20000, **kwargs):
def __init__(self, freeze_quantizer_until=20000, quantizer_dims=[1024], no_reconstruction=True, **kwargs):
super().__init__()
self.internal_step = 0
self.freeze_quantizer_until = freeze_quantizer_until
self.diff = TransformerDiffusion(**kwargs)
self.quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256,
codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5)
self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims,
codevector_dim=quantizer_dims[0],
codebook_size=256, codebook_groups=2,
max_gumbel_temperature=4, min_gumbel_temperature=.5)
self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature
del self.quantizer.up
if no_reconstruction:
del self.quantizer.up
def update_for_step(self, step, *args):
self.internal_step = step
@ -217,26 +220,28 @@ class TransformerDiffusionWithQuantizer(nn.Module):
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
with torch.set_grad_enabled(quant_grad_enabled):
proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True)
proj = proj.permute(0,2,1)
mse, diversity_loss, proj = self.quantizer(truth_mel, return_decoder_latent=True)
proj = proj.permute(0,2,1)
# Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
if not quant_grad_enabled:
proj = proj.detach()
unused = 0
for p in self.quantizer.parameters():
unused = unused + p.mean() * 0
proj = proj + unused
diversity_loss = diversity_loss * 0
diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
if disable_diversity:
return diff
return diff, diversity_loss
diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
conditioning_free=conditioning_free)
if mse is None:
return diff, diversity_loss
return diff, diversity_loss, mse
def get_debug_values(self, step, __):
if self.quantizer.total_codes > 0:
return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes],
return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes],
'gumbel_temperature': self.quantizer.quantizer.temperature}
else:
return {}
@ -314,25 +319,26 @@ def register_transformer_diffusion8_with_ar_prior(opt_net, opt):
def test_quant_model():
clip = torch.randn(2, 256, 400)
cond = torch.randn(2, 256, 400)
clip = torch.randn(2, 100, 401)
ts = torch.LongTensor([600, 600])
model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024,
input_vec_dim=1024, num_layers=16, prenet_layers=6)
model.get_grad_norm_parameter_groups()
model = TransformerDiffusionWithQuantizer(in_channels=100, out_channels=200, quantizer_dims=[1024,768,512,384],
model_channels=2048, block_channels=1024, prenet_channels=1024,
input_vec_dim=1024, num_layers=16, prenet_layers=6,
no_reconstruction=False)
#model.get_grad_norm_parameter_groups()
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
#quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth')
#diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth')
model.quantizer.load_state_dict(quant_weights, strict=False)
#model.quantizer.load_state_dict(quant_weights, strict=False)
#model.diff.load_state_dict(diff_weights)
torch.save(model.state_dict(), 'sample.pth')
#torch.save(model.state_dict(), 'sample.pth')
print_network(model)
o = model(clip, ts, clip, cond)
o = model(clip, ts, clip)
def test_ar_model():
clip = torch.randn(2, 256, 400)
clip = torch.randn(2, 256, 401)
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = TransformerDiffusionWithARPrior(model_channels=2048, block_channels=1024, prenet_channels=1024,
@ -355,4 +361,4 @@ def test_ar_model():
if __name__ == '__main__':
test_ar_model()
test_quant_model()