a few fixes

This commit is contained in:
James Betker 2022-06-01 16:35:15 -06:00
parent 712e0e82f7
commit b2a83efe50
3 changed files with 26 additions and 13 deletions

View File

@ -227,12 +227,17 @@ class TransformerDiffusionWithQuantizer(nn.Module):
self.m2v.min_gumbel_temperature,
)
def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1)
proj = self.m2v.m2v.projector.layer_norm(proj)
vectors, _, probs = self.m2v.quantizer(proj, return_probs=True)
vectors, perplexity, probs = self.m2v.quantizer(proj, return_probs=True)
diversity = (self.m2v.quantizer.num_codevectors - perplexity) / self.m2v.quantizer.num_codevectors
self.log_codes(probs)
return self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
diff = self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
if disable_diversity:
return diff
else:
return diff, diversity
def log_codes(self, codes):
if self.internal_step % 5 == 0:

View File

@ -217,7 +217,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
self.m2v.min_gumbel_temperature,
)
def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
with torch.set_grad_enabled(quant_grad_enabled):
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
@ -231,8 +231,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
proj = proj + unused
diversity_loss = diversity_loss * 0
return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
conditioning_free=conditioning_free), diversity_loss
diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
if disable_diversity:
return diff
return diff, diversity_loss
def get_debug_values(self, step, __):
if self.m2v.total_codes > 0:

View File

@ -114,9 +114,15 @@ class MusicDiffusionFid(evaluator.Evaluator):
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_mel(mel)
gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
def denoising_fn(x):
q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
s = q9.clamp(1, 9999999999)
x = x.clamp(-s, s) / s
return x
gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, denoised_fn=denoising_fn, clip_denoised=False,
model_kwargs={'truth_mel': mel,
'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])})
'conditioning_input': torch.zeros_like(mel_norm[:,:,:390]),
'disable_diversity': True})
gen_mel_denorm = denormalize_mel(gen_mel)
output_shape = (1,16,audio.shape[-1]//16)
@ -163,7 +169,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
gen_projections = []
real_projections = []
for i in tqdm(list(range(0, len(self.data), self.skip))):
path = self.data[i + self.env['rank']]
path = self.data[(i + self.env['rank']) % len(self.data)]
audio = load_audio(path, 22050).to(self.dev)
audio = audio[:, :22050*10]
sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
@ -195,13 +201,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator',
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\train.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\music_tfd5_with_quantizer_basis.pth'
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\models\\27000_generator_ema.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
'conditioning_free': False, 'conditioning_free_k': 2,
'conditioning_free': True, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 558, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 560, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval())