Support causal diffusion in inference

This commit is contained in:
James Betker 2022-07-08 14:27:19 -06:00
parent ba1699cee2
commit b99af89c8f
2 changed files with 35 additions and 21 deletions

View File

@ -336,7 +336,7 @@ class GaussianDiffusion:
if self.conditioning_free:
if self.ramp_conditioning_free:
assert t.shape[0] == 1 # This should only be used in inference.
cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t).float().mean().item() / self.num_timesteps)
else:
cfk = self.conditioning_free_k
model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
@ -660,9 +660,12 @@ class GaussianDiffusion:
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
if len(t.shape) == 2:
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
else:
nonzero_mask = (t != 0).float()
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
@ -710,11 +713,13 @@ class GaussianDiffusion:
shape,
noise=None,
clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
progress=True,
eta=0.0,
):
"""
@ -728,6 +733,8 @@ class GaussianDiffusion:
shape,
noise=noise,
clip_denoised=clip_denoised,
causal=causal,
causal_slope=causal_slope,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
@ -744,11 +751,13 @@ class GaussianDiffusion:
shape,
noise=None,
clip_denoised=True,
causal=False,
causal_slope=1,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
progress=True,
eta=0.0,
):
"""
@ -772,8 +781,15 @@ class GaussianDiffusion:
indices = tqdm(indices)
orig_img = img
for i in indices:
t = th.tensor([i] * shape[0], device=device)
mask = torch.zeros_like(img)
if causal:
t = causal_timestep_adjustment(t, shape[-1], self.num_timesteps, causal_slope, add_jitter=False).unsqueeze(1)
mask = t == self.num_timesteps
t[mask] = self.num_timesteps-1
mask = mask.repeat(img.shape[0], img.shape[1], 1)
with th.no_grad():
out = self.ddim_sample(
model,
@ -787,6 +803,8 @@ class GaussianDiffusion:
)
yield out
img = out["sample"]
if torch.any(mask):
img[mask] = orig_img[mask] # For causal diffusion, keep resetting these predictions until they are unmasked.
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None

View File

@ -37,6 +37,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
self.data = self.load_data(self.real_path)
self.clip = opt_get(opt_eval, ['clip_audio'], True) # Recommend setting true for more efficient eval passes.
self.ddim = opt_get(opt_eval, ['use_ddim'], False)
self.causal = opt_get(opt_eval, ['causal'], True)
self.causal_slope = opt_get(opt_eval, ['causal_slope'], 1)
if distributed.is_initialized() and distributed.get_world_size() > 1:
self.skip = distributed.get_world_size() # One batch element per GPU.
else:
@ -84,15 +86,6 @@ class MusicDiffusionFid(evaluator.Evaluator):
conditioning_free=True, conditioning_free_k=1)
self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent.
self.local_modules['spec_decoder'] = self.spec_decoder
elif 'cheater_gen_fake_ar' == mode:
self.diffusion_fn = self.perform_fake_ar_reconstruction_from_cheater_gen
self.local_modules['cheater_encoder'] = get_cheater_encoder()
self.local_modules['cheater_decoder'] = get_cheater_decoder()
self.cheater_decoder_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
conditioning_free=True, conditioning_free_k=1)
self.spec_decoder = get_mel2wav_v3_model() # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent.
self.local_modules['spec_decoder'] = self.spec_decoder
elif 'from_ar_prior' == mode:
self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
self.local_modules['cheater_encoder'] = get_cheater_encoder()
@ -230,7 +223,9 @@ class MusicDiffusionFid(evaluator.Evaluator):
cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
# 1. Generate the cheater latent using the input as a reference.
gen_cheater = self.diffuser.ddim_sample_loop(self.model, cheater.shape, progress=True, model_kwargs={'conditioning_input': cheater})
gen_cheater = self.diffuser.ddim_sample_loop(self.model, cheater.shape, progress=True,
model_kwargs={'conditioning_input': cheater},
causal=self.causal, causal_slope=self.causal_slope)
# 2. Decode the cheater into a MEL. This operation and the next need to be chunked to make them feasible to perform within GPU memory.
chunks = torch.split(gen_cheater, 64, dim=-1)
@ -400,9 +395,9 @@ class MusicDiffusionFid(evaluator.Evaluator):
real_projections = []
for i in tqdm(list(range(0, len(self.data), self.skip))):
path = self.data[(i + self.env['rank']) % len(self.data)]
#audio = load_audio(path, 22050).to(self.dev)
audio = load_audio('C:\\Users\\James\\Music\\another_longer_sample.wav', 22050).to(self.dev) # <- hack, remove it!
audio = audio[:, :1764000]
audio = load_audio(path, 22050).to(self.dev)
#audio = load_audio('C:\\Users\\James\\Music\\another_longer_sample.wav', 22050).to(self.dev) # <- hack, remove it!
#audio = audio[:, :1764000]
if self.clip:
audio = audio[:, :100000]
sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
@ -436,13 +431,14 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen_r8.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5\\models\\203000_generator_ema.pth'
load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal\\models\\1000_generator.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 64,
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': False,
'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen_fake_ar',
'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
'causal': True, 'causal_slope': 4,
#'partial_low': 128, 'partial_high': 192
}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 232, 'device': 'cuda', 'opt': {}}