forked from mrq/DL-Art-School
a few fixes
This commit is contained in:
parent
712e0e82f7
commit
b2a83efe50
|
@ -227,12 +227,17 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
||||||
self.m2v.min_gumbel_temperature,
|
self.m2v.min_gumbel_temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
|
def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
|
||||||
proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1)
|
proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1)
|
||||||
proj = self.m2v.m2v.projector.layer_norm(proj)
|
proj = self.m2v.m2v.projector.layer_norm(proj)
|
||||||
vectors, _, probs = self.m2v.quantizer(proj, return_probs=True)
|
vectors, perplexity, probs = self.m2v.quantizer(proj, return_probs=True)
|
||||||
|
diversity = (self.m2v.quantizer.num_codevectors - perplexity) / self.m2v.quantizer.num_codevectors
|
||||||
self.log_codes(probs)
|
self.log_codes(probs)
|
||||||
return self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
|
diff = self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
|
||||||
|
if disable_diversity:
|
||||||
|
return diff
|
||||||
|
else:
|
||||||
|
return diff, diversity
|
||||||
|
|
||||||
def log_codes(self, codes):
|
def log_codes(self, codes):
|
||||||
if self.internal_step % 5 == 0:
|
if self.internal_step % 5 == 0:
|
||||||
|
|
|
@ -217,7 +217,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
||||||
self.m2v.min_gumbel_temperature,
|
self.m2v.min_gumbel_temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
|
def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
|
||||||
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
||||||
with torch.set_grad_enabled(quant_grad_enabled):
|
with torch.set_grad_enabled(quant_grad_enabled):
|
||||||
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
|
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
|
||||||
|
@ -231,8 +231,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
||||||
proj = proj + unused
|
proj = proj + unused
|
||||||
diversity_loss = diversity_loss * 0
|
diversity_loss = diversity_loss * 0
|
||||||
|
|
||||||
return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
|
diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
|
||||||
conditioning_free=conditioning_free), diversity_loss
|
if disable_diversity:
|
||||||
|
return diff
|
||||||
|
return diff, diversity_loss
|
||||||
|
|
||||||
def get_debug_values(self, step, __):
|
def get_debug_values(self, step, __):
|
||||||
if self.m2v.total_codes > 0:
|
if self.m2v.total_codes > 0:
|
||||||
|
|
|
@ -114,9 +114,15 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
mel = self.spec_fn({'in': audio})['out']
|
mel = self.spec_fn({'in': audio})['out']
|
||||||
mel_norm = normalize_mel(mel)
|
mel_norm = normalize_mel(mel)
|
||||||
gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
|
def denoising_fn(x):
|
||||||
|
q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
|
||||||
|
s = q9.clamp(1, 9999999999)
|
||||||
|
x = x.clamp(-s, s) / s
|
||||||
|
return x
|
||||||
|
gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, denoised_fn=denoising_fn, clip_denoised=False,
|
||||||
model_kwargs={'truth_mel': mel,
|
model_kwargs={'truth_mel': mel,
|
||||||
'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])})
|
'conditioning_input': torch.zeros_like(mel_norm[:,:,:390]),
|
||||||
|
'disable_diversity': True})
|
||||||
|
|
||||||
gen_mel_denorm = denormalize_mel(gen_mel)
|
gen_mel_denorm = denormalize_mel(gen_mel)
|
||||||
output_shape = (1,16,audio.shape[-1]//16)
|
output_shape = (1,16,audio.shape[-1]//16)
|
||||||
|
@ -163,7 +169,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
gen_projections = []
|
gen_projections = []
|
||||||
real_projections = []
|
real_projections = []
|
||||||
for i in tqdm(list(range(0, len(self.data), self.skip))):
|
for i in tqdm(list(range(0, len(self.data), self.skip))):
|
||||||
path = self.data[i + self.env['rank']]
|
path = self.data[(i + self.env['rank']) % len(self.data)]
|
||||||
audio = load_audio(path, 22050).to(self.dev)
|
audio = load_audio(path, 22050).to(self.dev)
|
||||||
audio = audio[:, :22050*10]
|
audio = audio[:, :22050*10]
|
||||||
sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
|
sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
|
||||||
|
@ -195,13 +201,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator',
|
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\train.yml', 'generator',
|
||||||
also_load_savepoint=False,
|
also_load_savepoint=False,
|
||||||
load_path='X:\\dlas\\experiments\\music_tfd5_with_quantizer_basis.pth'
|
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\models\\27000_generator_ema.pth'
|
||||||
).cuda()
|
).cuda()
|
||||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
|
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
|
||||||
'conditioning_free': False, 'conditioning_free_k': 2,
|
'conditioning_free': True, 'conditioning_free_k': 1,
|
||||||
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
|
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
|
||||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 558, 'device': 'cuda', 'opt': {}}
|
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 560, 'device': 'cuda', 'opt': {}}
|
||||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||||
print(eval.perform_eval())
|
print(eval.perform_eval())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user