From 2134f06516e66f5e6ea7615deefccd0cff070188 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 27 Feb 2022 15:11:42 -0700
Subject: [PATCH] Implement conditioning-free diffusion at the eval level

---
 codes/models/diffusion/gaussian_diffusion.py      | 11 +++++++++++
 codes/scripts/audio/gen/speech_synthesis_utils.py |  5 +++--
 codes/trainer/eval/audio_diffusion_fid.py         |  6 +++++-
 3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index fafeaef3..21700fc1 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -123,11 +123,15 @@ class GaussianDiffusion:
         model_var_type,
         loss_type,
         rescale_timesteps=False,
+        conditioning_free=False,
+        conditioning_free_k=1,
     ):
         self.model_mean_type = ModelMeanType(model_mean_type)
         self.model_var_type = ModelVarType(model_var_type)
         self.loss_type = LossType(loss_type)
         self.rescale_timesteps = rescale_timesteps
+        self.conditioning_free = conditioning_free
+        self.conditioning_free_k = conditioning_free_k

         # Use float64 for accuracy.
         betas = np.array(betas, dtype=np.float64)
@@ -258,10 +262,14 @@ class GaussianDiffusion:
         B, C = x.shape[:2]
         assert t.shape == (B,)
         model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+        if self.conditioning_free:
+            model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)

         if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
             assert model_output.shape == (B, C * 2, *x.shape[2:])
             model_output, model_var_values = th.split(model_output, C, dim=1)
+            if self.conditioning_free:
+                model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
             if self.model_var_type == ModelVarType.LEARNED:
                 model_log_variance = model_var_values
                 model_variance = th.exp(model_log_variance)
@@ -290,6 +298,9 @@ class GaussianDiffusion:
         model_variance = _extract_into_tensor(model_variance, t, x.shape)
         model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

+        if self.conditioning_free:
+            model_output = (1 + self.conditioning_free_k) * model_output - self.conditioning_free_k * model_output_no_conditioning
+
         def process_xstart(x):
             if denoised_fn is not None:
                 x = denoised_fn(x)
diff --git a/codes/scripts/audio/gen/speech_synthesis_utils.py b/codes/scripts/audio/gen/speech_synthesis_utils.py
index 62fee92c..45fc0079 100644
--- a/codes/scripts/audio/gen/speech_synthesis_utils.py
+++ b/codes/scripts/audio/gen/speech_synthesis_utils.py
@@ -52,12 +52,13 @@ def load_gpt_conditioning_inputs_from_directory(path, num_candidates=3, sample_r
     return torch.stack(related_mels, dim=0)


-def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, schedule='linear'):
+def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, schedule='linear', enable_conditioning_free_guidance=False, conditioning_free_k=1):
     """
     Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
""" return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', - model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(schedule, trained_diffusion_steps)) + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(schedule, trained_diffusion_steps), + conditioning_free=enable_conditioning_free_guidance, conditioning_free_k=conditioning_free_k) def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, plt_spec=False, mean=False): diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index 5f8e892c..f584c0f1 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -38,7 +38,11 @@ class AudioDiffusionFid(evaluator.Evaluator): if diffusion_schedule is None: print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).") diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'cosine') - self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule) + conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False) + conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1) + self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule, + conditioning_free_diffusion_enabled=conditioning_free_diffusion_enabled, + conditioning_free_k=conditioning_free_k) self.dev = self.env['device'] mode = opt_get(opt_eval, ['diffusion_type'], 'tts') if mode == 'tts':