Support diffusion models with extra return values & inference in diffusion_dvae

This commit is contained in:
James Betker 2021-09-16 10:53:46 -06:00
parent 8d9857f33d
commit 6f48674647
3 changed files with 41 additions and 19 deletions
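At a glance, the change threads one extra model output (the DVAE commitment loss) from the generator, through the diffusion loss computation, and into the trainer state, and adds a code-conditioned inference path to the DVAE. A hedged summary of the flow; names not visible in the diff below are illustrative assumptions:

# Flow introduced by this commit (illustrative key names where not shown in the diff):
# 1. DiffusionDVAE.forward(x, t, spectrogram)      -> (diffusion_prediction, commitment_loss)
# 2. GaussianDiffusion.training_losses(model, ...) -> terms['extra_outputs'] = model_outputs[1:]
# 3. GaussianDiffusionInjector with opt['extra_model_output_keys'] = ['commitment_loss']
#    copies each extra output into the trainer state under the matching key.
# 4. GaussianDiffusionInferenceInjector with opt['model_function'] = 'decode'
#    samples through DiffusionDVAE.decode (discrete codes -> audio) instead of forward.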

View File

@@ -262,19 +262,8 @@ class DiffusionDVAE(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
# Compute DVAE portion first.
spec_logits = self.encoder(spectrogram).permute((0,2,1))
sampled, commitment_loss, codes = self.quantizer(spec_logits)
if self.training:
# Compute from softmax outputs to preserve gradients.
sampled = sampled.permute((0,2,1))
else:
# Compute from codes only.
sampled = self.quantizer.embed_code(codes).permute((0,2,1))
spec_hs = self.decoder(sampled)[::-1]
def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs):
spec_hs = self.decoder(embeddings)[::-1]
# Shape the spectrogram correctly. There is no guarantee it fits (though I should probably add an assertion here to make sure the resizing isn't too wacky).
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
convergence_fns = list(self.convergence_convs)
@@ -311,7 +300,26 @@ class DiffusionDVAE(nn.Module):
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h = h.type(x.dtype)
return self.out(h), commitment_loss
return self.out(h)
def decode(self, x, timesteps, codes, conditioning_inputs=None):
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs)
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
# Compute DVAE portion first.
spec_logits = self.encoder(spectrogram).permute((0,2,1))
sampled, commitment_loss, codes = self.quantizer(spec_logits)
if self.training:
# Compute from softmax outputs to preserve gradients.
embeddings = sampled.permute((0,2,1))
else:
# Compute from codes only.
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs), commitment_loss
@register_model
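To make the new split concrete, here is a hedged usage sketch of the two entry points: forward runs the full encoder/quantizer/decoder pipeline and returns the commitment loss as an extra output, while decode reconstructs purely from discrete codes. Tensor shapes, the codebook size, and the spectrogram layout below are assumptions, not taken from this diff:

import torch

def demo(model):
    # `model` is assumed to be a constructed DiffusionDVAE; shapes below are illustrative only.
    x = torch.randn(2, 1, 8192)              # noised waveform; length must be a multiple of 4096
    t = torch.randint(0, 1000, (2,))         # diffusion timesteps
    spectrogram = torch.randn(2, 80, 160)    # conditioning spectrogram (shape is an assumption)

    # Training/full path: encoder -> quantizer -> decoder, with the commitment loss as an extra output.
    prediction, commitment_loss = model(x, t, spectrogram)

    # Inference path: reconstruct from pre-computed discrete codes, skipping the encoder and quantizer.
    codes = torch.randint(0, 512, (2, 160))  # codebook indices; codebook size 512 is an assumption
    prediction = model.decode(x, t, codes)
    return prediction, commitment_loss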

View File

@@ -757,6 +757,7 @@ class GaussianDiffusion:
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
# TODO: support multiple model outputs for this mode.
terms["loss"] = self._vb_terms_bpd(
model=model,
x_start=x_start,
@@ -768,7 +769,10 @@
if self.loss_type == LossType.RESCALED_KL:
terms["loss"] *= self.num_timesteps
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
model_output = model_outputs[0]
if len(model_outputs) > 1:
terms['extra_outputs'] = model_outputs[1:]
if self.model_var_type in [
ModelVarType.LEARNED,

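The change above only assumes the model returns an indexable sequence: element 0 is the diffusion prediction, and anything after it is passed through untouched under terms['extra_outputs']. A minimal, self-contained sketch of that convention with toy stand-ins (not the real GaussianDiffusion class):

import torch

def toy_model(x_t, t):
    # Stands in for a diffusion network that now returns a tuple: (prediction, commitment_loss).
    return x_t * 0.5, torch.tensor(0.01)

def toy_training_losses(model, x_t, t):
    terms = {}
    model_outputs = model(x_t, t)        # may now be a tuple/sequence
    model_output = model_outputs[0]      # the diffusion prediction used for the MSE/VB terms
    if len(model_outputs) > 1:
        terms['extra_outputs'] = model_outputs[1:]      # passed through to the caller untouched
    terms['mse'] = ((model_output - x_t) ** 2).mean()   # placeholder loss for illustration only
    return terms

terms = toy_training_losses(toy_model, torch.randn(4, 8), torch.zeros(4, dtype=torch.long))
print(terms['extra_outputs'])   # (tensor(0.0100),)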
View File

@@ -21,6 +21,7 @@ class GaussianDiffusionInjector(Injector):
self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
@@ -30,9 +31,16 @@ class GaussianDiffusionInjector(Injector):
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
return {self.output: diffusion_outputs['mse'],
if len(self.extra_model_output_keys) > 0:
assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
else:
out = {}
out.update({self.output: diffusion_outputs['mse'],
self.output_variational_bounds_key: diffusion_outputs['vb'],
self.output_x_start_key: diffusion_outputs['x_start_predicted']}
self.output_x_start_key: diffusion_outputs['x_start_predicted']})
return out
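With extra_model_output_keys configured, each extra model output is copied into the trainer state under the matching key, where a later injector or loss can pick it up (for example, to weight the DVAE commitment loss). A hedged sketch of what such an injector configuration might look like; apart from extra_model_output_keys, the key names and values are illustrative, not taken from this diff:

# Hypothetical options for GaussianDiffusionInjector; only 'extra_model_output_keys'
# is introduced by this commit, everything else is an illustrative assumption.
injector_opt = {
    'type': 'gaussian_diffusion',
    'generator': 'dvae',
    'in': 'hq_audio',
    'out': 'mse_loss',
    'model_input_keys': ['spectrogram'],
    'extra_model_output_keys': ['commitment_loss'],  # must match the number of extra model outputs
    'sampler_type': 'uniform',
    'diffusion_args': {'betas': 'linear', 'steps': 4000},  # placeholder values
}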
class AutoregressiveGaussianDiffusionInjector(Injector):
@@ -67,7 +75,6 @@ class AutoregressiveGaussianDiffusionInjector(Injector):
return outputs
# Performs inference using a network trained to predict a reverse diffusion process, which yields an image.
class GaussianDiffusionInferenceInjector(Injector):
def __init__(self, opt, env):
@@ -89,6 +96,9 @@ class GaussianDiffusionInferenceInjector(Injector):
self.use_ema_model = opt_get(opt, ['use_ema'], False)
self.noise_style = opt_get(opt, ['noise_type'], 'random') # 'zero', 'fixed' or 'random'
self.model_fn = opt_get(opt, ['model_function'], None)
self.model_fn = None if self.model_fn is None else getattr(self.generator, self.model_fn)
def forward(self, state):
if self.use_ema_model:
gen = self.env['emas'][self.opt['generator']]
@@ -113,7 +123,7 @@ class GaussianDiffusionInferenceInjector(Injector):
if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
self.fixed_noise = torch.randn(output_shape, device=dev)
noise = self.fixed_noise
gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
gen = self.sampling_fn(self.model_fn, output_shape, noise=noise, model_kwargs=model_inputs, progress=True, device=dev)
if self.undo_n1_to_1:
gen = (gen + 1) / 2
return {self.output: gen}
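The new model_function option lets sampling call an arbitrary method on the generator instead of its forward; pointing it at decode means the p-sample loop drives the code-conditioned path of the DVAE. A hedged sketch of the relevant options; only model_function, use_ema, and noise_type appear in this diff, the remaining keys and values are assumptions:

# Hypothetical options for GaussianDiffusionInferenceInjector. 'model_function' is resolved with
# getattr(generator, 'decode') at construction time, as in the __init__ change above.
inference_opt = {
    'type': 'gaussian_diffusion_inference',
    'generator': 'dvae',
    'use_ema': True,
    'noise_type': 'random',        # 'zero', 'fixed' or 'random'
    'model_function': 'decode',    # sample through DiffusionDVAE.decode instead of forward
    'out': 'generated_audio',
}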