forked from mrq/DL-Art-School
a few fixes
parent 712e0e82f7
commit b2a83efe50
@@ -227,12 +227,17 @@ class TransformerDiffusionWithQuantizer(nn.Module):
                                                self.m2v.min_gumbel_temperature,
             )
 
-    def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
         proj = self.m2v.m2v.input_blocks(truth_mel).permute(0,2,1)
         proj = self.m2v.m2v.projector.layer_norm(proj)
-        vectors, _, probs = self.m2v.quantizer(proj, return_probs=True)
+        vectors, perplexity, probs = self.m2v.quantizer(proj, return_probs=True)
+        diversity = (self.m2v.quantizer.num_codevectors - perplexity) / self.m2v.quantizer.num_codevectors
         self.log_codes(probs)
-        return self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        diff = self.diff(x, timesteps, codes=vectors, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        if disable_diversity:
+            return diff
+        else:
+            return diff, diversity
 
     def log_codes(self, codes):
         if self.internal_step % 5 == 0:
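
The new diversity value is derived from the quantizer's perplexity. A minimal sketch of how that metric behaves, assuming the perplexity follows the usual wav2vec2-style definition over the soft code-assignment probabilities (the helper codebook_diversity and the toy tensors below are illustrative, not from the repo):

import torch

def codebook_diversity(probs, num_codevectors):
    # probs: (frames, num_codevectors) soft code-assignment probabilities.
    avg_probs = probs.mean(dim=0)
    # Perplexity of average usage: ~num_codevectors when every code is
    # used uniformly, ~1 when usage collapses onto a single code.
    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7)))
    # Mirrors the hunk above: ~0 for healthy usage, ->1 as the codebook collapses.
    return (num_codevectors - perplexity) / num_codevectors

probs = torch.softmax(torch.randn(4096, 512), dim=-1)
print(codebook_diversity(probs, 512))  # small value for near-uniform usage
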
@@ -217,7 +217,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
                                                self.m2v.min_gumbel_temperature,
             )
 
-    def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
             proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
@@ -231,8 +231,10 @@ class TransformerDiffusionWithQuantizer(nn.Module):
             proj = proj + unused
             diversity_loss = diversity_loss * 0
 
-        return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
-                         conditioning_free=conditioning_free), diversity_loss
+        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        if disable_diversity:
+            return diff
+        return diff, diversity_loss
 
     def get_debug_values(self, step, __):
         if self.m2v.total_codes > 0:
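
With both forward() variants now changing their return arity on the flag, a hypothetical call-site sketch (model stands in for a constructed TransformerDiffusionWithQuantizer; construction elided):

# Training path: tuple of (diffusion output, diversity term) for the loss.
diff_out, diversity_loss = model(x, timesteps, truth_mel, conditioning_input)
# Sampling path: a single tensor, so p_sample_loop can consume the output directly.
diff_out = model(x, timesteps, truth_mel, conditioning_input, disable_diversity=True)
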
@@ -114,9 +114,15 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
         mel = self.spec_fn({'in': audio})['out']
         mel_norm = normalize_mel(mel)
-        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
+        def denoising_fn(x):
+            q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
+            s = q9.clamp(1, 9999999999)
+            x = x.clamp(-s, s) / s
+            return x
+        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, denoised_fn=denoising_fn, clip_denoised=False,
                                               model_kwargs={'truth_mel': mel,
-                                                            'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])})
+                                                            'conditioning_input': torch.zeros_like(mel_norm[:,:,:390]),
+                                                            'disable_diversity': True})
 
         gen_mel_denorm = denormalize_mel(gen_mel)
         output_shape = (1,16,audio.shape[-1]//16)
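
The denoising_fn added here is a dynamic-thresholding clamp applied to each denoised prediction. The same logic as a standalone, runnable sketch (the toy tensor is illustrative):

import torch

def denoising_fn(x):
    # Per-row 95th percentile along the last dim; the floor of 1 in the
    # clamp leaves values already inside [-1, 1] unscaled.
    q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
    s = q9.clamp(1, 9999999999)
    # Clamp outliers to +/- s, then renormalize back into [-1, 1].
    return x.clamp(-s, s) / s

x = torch.randn(2, 100, 400) * 3          # toy "denoised mel" with outliers
print(denoising_fn(x).abs().max() <= 1)   # tensor(True)
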
@@ -163,7 +169,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         gen_projections = []
         real_projections = []
         for i in tqdm(list(range(0, len(self.data), self.skip))):
-            path = self.data[i + self.env['rank']]
+            path = self.data[(i + self.env['rank']) % len(self.data)]
             audio = load_audio(path, 22050).to(self.dev)
             audio = audio[:, :22050*10]
             sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
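
The modulo guard fixes an out-of-bounds read: with multiple ranks striding the dataset, i + rank can step past the end on the final iteration. A toy illustration (values made up):

data = list(range(10))
skip, rank = 4, 3
for i in range(0, len(data), skip):       # i = 0, 4, 8
    # data[i + rank] would raise IndexError at i=8 (index 11);
    # wrapping keeps every rank in bounds: indices 3, 7, 1.
    path = data[(i + rank) % len(data)]
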
@@ -195,13 +201,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
 
 if __name__ == '__main__':
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant.yml', 'generator',
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\train.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\music_tfd5_with_quantizer_basis.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\models\\27000_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
-                'conditioning_free': False, 'conditioning_free_k': 2,
+                'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 558, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 560, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
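
The eval config now enables conditioning-free sampling with conditioning_free_k=1. Assuming that option applies the usual classifier-free-guidance weighting, a sketch of the combination rule (eps_cond/eps_uncond are hypothetical names for the conditioned and unconditioned model outputs; not the repo's verbatim implementation):

def cfg_combine(eps_cond, eps_uncond, k):
    # k = 0 returns the conditioned prediction unchanged; larger k pushes
    # the result further away from the unconditioned prediction.
    return (1 + k) * eps_cond - k * eps_uncond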