Implement conditioning-free diffusion at the eval level

This commit is contained in:
James Betker 2022-02-27 15:11:42 -07:00
parent 436fe24822
commit 2134f06516
3 changed files with 19 additions and 3 deletions

View File

@ -123,11 +123,15 @@ class GaussianDiffusion:
model_var_type, model_var_type,
loss_type, loss_type,
rescale_timesteps=False, rescale_timesteps=False,
conditioning_free=False,
conditioning_free_k=1,
): ):
self.model_mean_type = ModelMeanType(model_mean_type) self.model_mean_type = ModelMeanType(model_mean_type)
self.model_var_type = ModelVarType(model_var_type) self.model_var_type = ModelVarType(model_var_type)
self.loss_type = LossType(loss_type) self.loss_type = LossType(loss_type)
self.rescale_timesteps = rescale_timesteps self.rescale_timesteps = rescale_timesteps
self.conditioning_free = conditioning_free
self.conditioning_free_k = conditioning_free_k
# Use float64 for accuracy. # Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64) betas = np.array(betas, dtype=np.float64)
@ -258,10 +262,14 @@ class GaussianDiffusion:
B, C = x.shape[:2] B, C = x.shape[:2]
assert t.shape == (B,) assert t.shape == (B,)
model_output = model(x, self._scale_timesteps(t), **model_kwargs) model_output = model(x, self._scale_timesteps(t), **model_kwargs)
if self.conditioning_free:
model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:]) assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1) model_output, model_var_values = th.split(model_output, C, dim=1)
if self.conditioning_free:
model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
if self.model_var_type == ModelVarType.LEARNED: if self.model_var_type == ModelVarType.LEARNED:
model_log_variance = model_var_values model_log_variance = model_var_values
model_variance = th.exp(model_log_variance) model_variance = th.exp(model_log_variance)
@ -290,6 +298,9 @@ class GaussianDiffusion:
model_variance = _extract_into_tensor(model_variance, t, x.shape) model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
if self.conditioning_free:
model_output = (1 + self.conditioning_free_k) * model_output - self.conditioning_free_k * model_output_no_conditioning
def process_xstart(x): def process_xstart(x):
if denoised_fn is not None: if denoised_fn is not None:
x = denoised_fn(x) x = denoised_fn(x)

View File

@ -52,12 +52,13 @@ def load_gpt_conditioning_inputs_from_directory(path, num_candidates=3, sample_r
return torch.stack(related_mels, dim=0) return torch.stack(related_mels, dim=0)
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, schedule='linear'): def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, schedule='linear', enable_conditioning_free_guidance=False, conditioning_free_k=1):
""" """
Helper function to load a GaussianDiffusion instance configured for use as a vocoder. Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
""" """
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(schedule, trained_diffusion_steps)) model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(schedule, trained_diffusion_steps),
conditioning_free=enable_conditioning_free_guidance, conditioning_free_k=conditioning_free_k)
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, plt_spec=False, mean=False): def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, plt_spec=False, mean=False):

View File

@ -38,7 +38,11 @@ class AudioDiffusionFid(evaluator.Evaluator):
if diffusion_schedule is None: if diffusion_schedule is None:
print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).") print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'cosine') diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'cosine')
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule) conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
conditioning_free_diffusion_enabled=conditioning_free_diffusion_enabled,
conditioning_free_k=conditioning_free_k)
self.dev = self.env['device'] self.dev = self.env['device']
mode = opt_get(opt_eval, ['diffusion_type'], 'tts') mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
if mode == 'tts': if mode == 'tts':