Support diffusion models with extra return values & inference in diffusion_dvae
This commit is contained in:
parent
8d9857f33d
commit
6f48674647
|
@ -262,19 +262,8 @@ class DiffusionDVAE(nn.Module):
|
||||||
self.middle_block.apply(convert_module_to_f32)
|
self.middle_block.apply(convert_module_to_f32)
|
||||||
self.output_blocks.apply(convert_module_to_f32)
|
self.output_blocks.apply(convert_module_to_f32)
|
||||||
|
|
||||||
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
|
def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs):
|
||||||
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
|
spec_hs = self.decoder(embeddings)[::-1]
|
||||||
|
|
||||||
# Compute DVAE portion first.
|
|
||||||
spec_logits = self.encoder(spectrogram).permute((0,2,1))
|
|
||||||
sampled, commitment_loss, codes = self.quantizer(spec_logits)
|
|
||||||
if self.training:
|
|
||||||
# Compute from softmax outputs to preserve gradients.
|
|
||||||
sampled = sampled.permute((0,2,1))
|
|
||||||
else:
|
|
||||||
# Compute from codes only.
|
|
||||||
sampled = self.quantizer.embed_code(codes).permute((0,2,1))
|
|
||||||
spec_hs = self.decoder(sampled)[::-1]
|
|
||||||
# Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.)
|
# Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.)
|
||||||
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
|
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
|
||||||
convergence_fns = list(self.convergence_convs)
|
convergence_fns = list(self.convergence_convs)
|
||||||
|
@ -311,7 +300,26 @@ class DiffusionDVAE(nn.Module):
|
||||||
h = torch.cat([h, hs.pop()], dim=1)
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
h = module(h, emb)
|
h = module(h, emb)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
return self.out(h), commitment_loss
|
return self.out(h)
|
||||||
|
|
||||||
|
def decode(self, x, timesteps, codes, conditioning_inputs=None):
|
||||||
|
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
|
||||||
|
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
|
||||||
|
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs), commitment_loss
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
|
||||||
|
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
|
||||||
|
|
||||||
|
# Compute DVAE portion first.
|
||||||
|
spec_logits = self.encoder(spectrogram).permute((0,2,1))
|
||||||
|
sampled, commitment_loss, codes = self.quantizer(spec_logits)
|
||||||
|
if self.training:
|
||||||
|
# Compute from softmax outputs to preserve gradients.
|
||||||
|
embeddings = sampled.permute((0,2,1))
|
||||||
|
else:
|
||||||
|
# Compute from codes only.
|
||||||
|
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
|
||||||
|
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs), commitment_loss
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
|
|
@ -757,6 +757,7 @@ class GaussianDiffusion:
|
||||||
terms = {}
|
terms = {}
|
||||||
|
|
||||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||||
|
# TODO: support multiple model outputs for this mode.
|
||||||
terms["loss"] = self._vb_terms_bpd(
|
terms["loss"] = self._vb_terms_bpd(
|
||||||
model=model,
|
model=model,
|
||||||
x_start=x_start,
|
x_start=x_start,
|
||||||
|
@ -768,7 +769,10 @@ class GaussianDiffusion:
|
||||||
if self.loss_type == LossType.RESCALED_KL:
|
if self.loss_type == LossType.RESCALED_KL:
|
||||||
terms["loss"] *= self.num_timesteps
|
terms["loss"] *= self.num_timesteps
|
||||||
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
||||||
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
|
model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
|
||||||
|
model_output = model_outputs[0]
|
||||||
|
if len(model_outputs) > 1:
|
||||||
|
terms['extra_outputs']: model_outputs[1:]
|
||||||
|
|
||||||
if self.model_var_type in [
|
if self.model_var_type in [
|
||||||
ModelVarType.LEARNED,
|
ModelVarType.LEARNED,
|
||||||
|
|
|
@ -21,6 +21,7 @@ class GaussianDiffusionInjector(Injector):
|
||||||
self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
|
self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
|
||||||
self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
|
self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
|
||||||
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
|
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
|
||||||
|
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
|
@ -30,9 +31,16 @@ class GaussianDiffusionInjector(Injector):
|
||||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||||
if isinstance(self.schedule_sampler, LossAwareSampler):
|
if isinstance(self.schedule_sampler, LossAwareSampler):
|
||||||
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
||||||
return {self.output: diffusion_outputs['mse'],
|
|
||||||
|
if len(self.extra_model_output_keys) > 0:
|
||||||
|
assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
|
||||||
|
out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
|
||||||
|
else:
|
||||||
|
out = {}
|
||||||
|
out.update({self.output: diffusion_outputs['mse'],
|
||||||
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
self.output_variational_bounds_key: diffusion_outputs['vb'],
|
||||||
self.output_x_start_key: diffusion_outputs['x_start_predicted']}
|
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class AutoregressiveGaussianDiffusionInjector(Injector):
|
class AutoregressiveGaussianDiffusionInjector(Injector):
|
||||||
|
@ -67,7 +75,6 @@ class AutoregressiveGaussianDiffusionInjector(Injector):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
|
# Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
|
||||||
class GaussianDiffusionInferenceInjector(Injector):
|
class GaussianDiffusionInferenceInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
|
@ -89,6 +96,9 @@ class GaussianDiffusionInferenceInjector(Injector):
|
||||||
self.use_ema_model = opt_get(opt, ['use_ema'], False)
|
self.use_ema_model = opt_get(opt, ['use_ema'], False)
|
||||||
self.noise_style = opt_get(opt, ['noise_type'], 'random') # 'zero', 'fixed' or 'random'
|
self.noise_style = opt_get(opt, ['noise_type'], 'random') # 'zero', 'fixed' or 'random'
|
||||||
|
|
||||||
|
self.model_fn = opt_get(opt, ['model_function'], None)
|
||||||
|
self.model_fn = None if self.model_fn is None else getattr(self.generator, self.model_fn)
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
if self.use_ema_model:
|
if self.use_ema_model:
|
||||||
gen = self.env['emas'][self.opt['generator']]
|
gen = self.env['emas'][self.opt['generator']]
|
||||||
|
@ -113,7 +123,7 @@ class GaussianDiffusionInferenceInjector(Injector):
|
||||||
if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
|
if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
|
||||||
self.fixed_noise = torch.randn(output_shape, device=dev)
|
self.fixed_noise = torch.randn(output_shape, device=dev)
|
||||||
noise = self.fixed_noise
|
noise = self.fixed_noise
|
||||||
gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
|
gen = self.sampling_fn(self.model_fn, output_shape, noise=noise, model_kwargs=model_inputs, progress=True, device=dev)
|
||||||
if self.undo_n1_to_1:
|
if self.undo_n1_to_1:
|
||||||
gen = (gen + 1) / 2
|
gen = (gen + 1) / 2
|
||||||
return {self.output: gen}
|
return {self.output: gen}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user