Implement conditioning-free diffusion at the eval level
This commit is contained in:
parent
436fe24822
commit
2134f06516
|
@ -123,11 +123,15 @@ class GaussianDiffusion:
|
|||
model_var_type,
|
||||
loss_type,
|
||||
rescale_timesteps=False,
|
||||
conditioning_free=False,
|
||||
conditioning_free_k=1,
|
||||
):
|
||||
self.model_mean_type = ModelMeanType(model_mean_type)
|
||||
self.model_var_type = ModelVarType(model_var_type)
|
||||
self.loss_type = LossType(loss_type)
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
self.conditioning_free = conditioning_free
|
||||
self.conditioning_free_k = conditioning_free_k
|
||||
|
||||
# Use float64 for accuracy.
|
||||
betas = np.array(betas, dtype=np.float64)
|
||||
|
@ -258,10 +262,14 @@ class GaussianDiffusion:
|
|||
B, C = x.shape[:2]
|
||||
assert t.shape == (B,)
|
||||
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
|
||||
if self.conditioning_free:
|
||||
model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
|
||||
|
||||
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
||||
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
||||
model_output, model_var_values = th.split(model_output, C, dim=1)
|
||||
if self.conditioning_free:
|
||||
model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
|
||||
if self.model_var_type == ModelVarType.LEARNED:
|
||||
model_log_variance = model_var_values
|
||||
model_variance = th.exp(model_log_variance)
|
||||
|
@ -290,6 +298,9 @@ class GaussianDiffusion:
|
|||
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
||||
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
||||
|
||||
if self.conditioning_free:
|
||||
model_output = (1 + self.conditioning_free_k) * model_output - self.conditioning_free_k * model_output_no_conditioning
|
||||
|
||||
def process_xstart(x):
|
||||
if denoised_fn is not None:
|
||||
x = denoised_fn(x)
|
||||
|
|
|
@ -52,12 +52,13 @@ def load_gpt_conditioning_inputs_from_directory(path, num_candidates=3, sample_r
|
|||
return torch.stack(related_mels, dim=0)
|
||||
|
||||
|
||||
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, schedule='linear'):
|
||||
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, schedule='linear', enable_conditioning_free_guidance=False, conditioning_free_k=1):
|
||||
"""
|
||||
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
||||
"""
|
||||
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(schedule, trained_diffusion_steps))
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(schedule, trained_diffusion_steps),
|
||||
conditioning_free=enable_conditioning_free_guidance, conditioning_free_k=conditioning_free_k)
|
||||
|
||||
|
||||
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, plt_spec=False, mean=False):
|
||||
|
|
|
@ -38,7 +38,11 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
if diffusion_schedule is None:
|
||||
print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
|
||||
diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'cosine')
|
||||
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule)
|
||||
conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
|
||||
conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
|
||||
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
|
||||
conditioning_free_diffusion_enabled=conditioning_free_diffusion_enabled,
|
||||
conditioning_free_k=conditioning_free_k)
|
||||
self.dev = self.env['device']
|
||||
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
||||
if mode == 'tts':
|
||||
|
|
Loading…
Reference in New Issue
Block a user