Implement conditioning-free diffusion at the eval level
This commit is contained in:
parent
436fe24822
commit
2134f06516
|
@ -123,11 +123,15 @@ class GaussianDiffusion:
|
||||||
model_var_type,
|
model_var_type,
|
||||||
loss_type,
|
loss_type,
|
||||||
rescale_timesteps=False,
|
rescale_timesteps=False,
|
||||||
|
conditioning_free=False,
|
||||||
|
conditioning_free_k=1,
|
||||||
):
|
):
|
||||||
self.model_mean_type = ModelMeanType(model_mean_type)
|
self.model_mean_type = ModelMeanType(model_mean_type)
|
||||||
self.model_var_type = ModelVarType(model_var_type)
|
self.model_var_type = ModelVarType(model_var_type)
|
||||||
self.loss_type = LossType(loss_type)
|
self.loss_type = LossType(loss_type)
|
||||||
self.rescale_timesteps = rescale_timesteps
|
self.rescale_timesteps = rescale_timesteps
|
||||||
|
self.conditioning_free = conditioning_free
|
||||||
|
self.conditioning_free_k = conditioning_free_k
|
||||||
|
|
||||||
# Use float64 for accuracy.
|
# Use float64 for accuracy.
|
||||||
betas = np.array(betas, dtype=np.float64)
|
betas = np.array(betas, dtype=np.float64)
|
||||||
|
@ -258,10 +262,14 @@ class GaussianDiffusion:
|
||||||
B, C = x.shape[:2]
|
B, C = x.shape[:2]
|
||||||
assert t.shape == (B,)
|
assert t.shape == (B,)
|
||||||
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
|
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
|
||||||
|
if self.conditioning_free:
|
||||||
|
model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
|
||||||
|
|
||||||
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
||||||
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
||||||
model_output, model_var_values = th.split(model_output, C, dim=1)
|
model_output, model_var_values = th.split(model_output, C, dim=1)
|
||||||
|
if self.conditioning_free:
|
||||||
|
model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
|
||||||
if self.model_var_type == ModelVarType.LEARNED:
|
if self.model_var_type == ModelVarType.LEARNED:
|
||||||
model_log_variance = model_var_values
|
model_log_variance = model_var_values
|
||||||
model_variance = th.exp(model_log_variance)
|
model_variance = th.exp(model_log_variance)
|
||||||
|
@ -290,6 +298,9 @@ class GaussianDiffusion:
|
||||||
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
||||||
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
||||||
|
|
||||||
|
if self.conditioning_free:
|
||||||
|
model_output = (1 + self.conditioning_free_k) * model_output - self.conditioning_free_k * model_output_no_conditioning
|
||||||
|
|
||||||
def process_xstart(x):
|
def process_xstart(x):
|
||||||
if denoised_fn is not None:
|
if denoised_fn is not None:
|
||||||
x = denoised_fn(x)
|
x = denoised_fn(x)
|
||||||
|
|
|
@ -52,12 +52,13 @@ def load_gpt_conditioning_inputs_from_directory(path, num_candidates=3, sample_r
|
||||||
return torch.stack(related_mels, dim=0)
|
return torch.stack(related_mels, dim=0)
|
||||||
|
|
||||||
|
|
||||||
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, schedule='linear'):
|
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, schedule='linear', enable_conditioning_free_guidance=False, conditioning_free_k=1):
|
||||||
"""
|
"""
|
||||||
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
||||||
"""
|
"""
|
||||||
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
||||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(schedule, trained_diffusion_steps))
|
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(schedule, trained_diffusion_steps),
|
||||||
|
conditioning_free=enable_conditioning_free_guidance, conditioning_free_k=conditioning_free_k)
|
||||||
|
|
||||||
|
|
||||||
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, plt_spec=False, mean=False):
|
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, plt_spec=False, mean=False):
|
||||||
|
|
|
@ -38,7 +38,11 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
if diffusion_schedule is None:
|
if diffusion_schedule is None:
|
||||||
print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
|
print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
|
||||||
diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'cosine')
|
diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'cosine')
|
||||||
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule)
|
conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
|
||||||
|
conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
|
||||||
|
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
|
||||||
|
enable_conditioning_free_guidance=conditioning_free_diffusion_enabled,
|
||||||
|
conditioning_free_k=conditioning_free_k)
|
||||||
self.dev = self.env['device']
|
self.dev = self.env['device']
|
||||||
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
||||||
if mode == 'tts':
|
if mode == 'tts':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user