mdf: provide conditioning margin

This commit is contained in:
James Betker 2022-07-14 21:38:14 -06:00
parent 4d5688be47
commit 3b12d348fc

View File

@ -230,9 +230,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
# 1. Generate the cheater latent using the input as a reference.
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
gen_cheater = sampler(self.model, cheater.shape, progress=True,
output_shape = (1, 256, cheater.shape[-1]-80)
gen_cheater = sampler(self.model, output_shape, progress=True,
causal=self.causal, causal_slope=self.causal_slope,
model_kwargs={'conditioning_input': cheater})
model_kwargs={'conditioning_input': cheater, 'cond_start': 40})
# 2. Decode the cheater into a MEL
gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,gen_cheater.shape[-1]*16), progress=True,