From 6f486746471a98789e9317203af51046fa976833 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 16 Sep 2021 10:53:46 -0600
Subject: [PATCH] Support diffusion models with extra return values & inference
 in diffusion_dvae

---
 codes/models/diffusion/diffusion_dvae.py      | 36 +++++++++++--------
 codes/models/diffusion/gaussian_diffusion.py  |  6 +++-
 .../injectors/gaussian_diffusion_injector.py  | 19 ++++++++---
 3 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py
index f489e241..a22c57fe 100644
--- a/codes/models/diffusion/diffusion_dvae.py
+++ b/codes/models/diffusion/diffusion_dvae.py
@@ -262,19 +262,8 @@ class DiffusionDVAE(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)
 
-    def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
-        assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
-
-        # Compute DVAE portion first.
-        spec_logits = self.encoder(spectrogram).permute((0,2,1))
-        sampled, commitment_loss, codes = self.quantizer(spec_logits)
-        if self.training:
-            # Compute from softmax outputs to preserve gradients.
-            sampled = sampled.permute((0,2,1))
-        else:
-            # Compute from codes only.
-            sampled = self.quantizer.embed_code(codes).permute((0,2,1))
-        spec_hs = self.decoder(sampled)[::-1]
+    def _decode_continuous(self, x, timesteps, embeddings, conditioning_inputs):
+        spec_hs = self.decoder(embeddings)[::-1]
         # Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.)
         spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
         convergence_fns = list(self.convergence_convs)
@@ -311,7 +300,26 @@ class DiffusionDVAE(nn.Module):
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
-        return self.out(h), commitment_loss
+        return self.out(h)
+
+    def decode(self, x, timesteps, codes, conditioning_inputs=None):
+        assert x.shape[-1] % 4096 == 0  # This model operates at base//4096 at its bottom levels, hence this requirement.
+        embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
+        return self._decode_continuous(x, timesteps, embeddings, conditioning_inputs)
+
+    def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
+        assert x.shape[-1] % 4096 == 0  # This model operates at base//4096 at its bottom levels, hence this requirement.
+
+        # Compute DVAE portion first.
+        spec_logits = self.encoder(spectrogram).permute((0,2,1))
+        sampled, commitment_loss, codes = self.quantizer(spec_logits)
+        if self.training:
+            # Compute from softmax outputs to preserve gradients.
+            embeddings = sampled.permute((0,2,1))
+        else:
+            # Compute from codes only.
+            embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
+        return self._decode_continuous(x, timesteps, embeddings, conditioning_inputs), commitment_loss
 
 
 @register_model
diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index ac2f3d80..0980ae8a 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -757,6 +757,7 @@ class GaussianDiffusion:
         terms = {}
 
         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+            # TODO: support multiple model outputs for this mode.
             terms["loss"] = self._vb_terms_bpd(
                 model=model,
                 x_start=x_start,
@@ -768,7 +769,10 @@ class GaussianDiffusion:
             if self.loss_type == LossType.RESCALED_KL:
                 terms["loss"] *= self.num_timesteps
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
-            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
+            model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
+            model_output = model_outputs[0]
+            if len(model_outputs) > 1:
+                terms['extra_outputs'] = model_outputs[1:]
 
             if self.model_var_type in [
                 ModelVarType.LEARNED,
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index f9fa05ca..d63d22e0 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -21,6 +21,7 @@ class GaussianDiffusionInjector(Injector):
         self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
         self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
+        self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
@@ -30,9 +31,16 @@ class GaussianDiffusionInjector(Injector):
         diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
         if isinstance(self.schedule_sampler, LossAwareSampler):
             self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
-        return {self.output: diffusion_outputs['mse'],
+
+        if len(self.extra_model_output_keys) > 0:
+            assert len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs'])
+            out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
+        else:
+            out = {}
+        out.update({self.output: diffusion_outputs['mse'],
                 self.output_variational_bounds_key: diffusion_outputs['vb'],
-                self.output_x_start_key: diffusion_outputs['x_start_predicted']}
+                self.output_x_start_key: diffusion_outputs['x_start_predicted']})
+        return out
 
 
 class AutoregressiveGaussianDiffusionInjector(Injector):
@@ -67,7 +75,6 @@ class AutoregressiveGaussianDiffusionInjector(Injector):
         return outputs
 
 
-
 # Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
 class GaussianDiffusionInferenceInjector(Injector):
     def __init__(self, opt, env):
@@ -89,6 +96,9 @@ class GaussianDiffusionInferenceInjector(Injector):
         self.use_ema_model = opt_get(opt, ['use_ema'], False)
         self.noise_style = opt_get(opt, ['noise_type'], 'random')  # 'zero', 'fixed' or 'random'
 
+        # Optional name of a generator method to sample through (e.g. 'decode'); resolved in forward().
+        self.model_fn = opt_get(opt, ['model_function'], None)
+
     def forward(self, state):
         if self.use_ema_model:
             gen = self.env['emas'][self.opt['generator']]
@@ -113,7 +123,8 @@ class GaussianDiffusionInferenceInjector(Injector):
             if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
                 self.fixed_noise = torch.randn(output_shape, device=dev)
             noise = self.fixed_noise
-        gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
+        model_fn = gen if self.model_fn is None else getattr(gen, self.model_fn)
+        gen = self.sampling_fn(model_fn, output_shape, noise=noise, model_kwargs=model_inputs, progress=True, device=dev)
         if self.undo_n1_to_1:
             gen = (gen + 1) / 2
         return {self.output: gen}
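
Note on the new model-output convention: with this patch, a diffusion model's
forward() may return a tuple whose first element is the diffusion prediction
and whose remaining elements are extra outputs, which training_losses() then
surfaces under terms['extra_outputs']. The following minimal sketch shows that
contract end to end; ToyDiffusionModel is a hypothetical stand-in for a model
like DiffusionDVAE, not code from this repository.

    import torch
    import torch.nn as nn

    # Hypothetical model following the tuple-return convention: the first
    # element is the diffusion prediction, the rest are extra outputs (here a
    # stand-in for the DVAE's commitment loss).
    class ToyDiffusionModel(nn.Module):
        def __init__(self, channels=4):
            super().__init__()
            self.net = nn.Conv1d(channels, channels, kernel_size=3, padding=1)

        def forward(self, x, timesteps):
            prediction = self.net(x)
            commitment_loss = prediction.pow(2).mean()
            return prediction, commitment_loss

    model = ToyDiffusionModel()
    x_t = torch.randn(2, 4, 32)            # noised input at timestep t
    t = torch.randint(0, 1000, (2,))

    # Mirrors the unpacking added to GaussianDiffusion.training_losses():
    model_outputs = model(x_t, t)
    model_output = model_outputs[0]        # used for the MSE/VLB terms as before
    if len(model_outputs) > 1:
        extra_outputs = model_outputs[1:]  # exposed as terms['extra_outputs']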
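
Note on configuration: both new options are read with opt_get, so both are
optional. extra_model_output_keys gives GaussianDiffusionInjector one state key
per extra model output, assigned in return order; model_function lets
GaussianDiffusionInferenceInjector sample through an alternate generator method
such as decode(). A rough sketch of the relevant opt entries follows; the key
values are illustrative guesses, not taken from a real config in this repo.

    # Training-time injector: DiffusionDVAE.forward() returns
    # (prediction, commitment_loss), so one extra key is needed.
    training_injector_opt = {
        'type': 'gaussian_diffusion',
        'generator': 'dvae',
        'sampler_type': 'uniform',
        'model_input_keys': ['spectrogram'],
        'extra_model_output_keys': ['commitment_loss'],
    }

    # Inference-time injector: sample through generator.decode()
    # (codes -> waveform) instead of forward().
    inference_injector_opt = {
        'type': 'gaussian_diffusion_inference',
        'generator': 'dvae',
        'model_function': 'decode',
        'model_input_keys': ['codes'],
    }

Because the training injector asserts that len(extra_model_output_keys) matches
the number of extra outputs, a mismatched config fails loudly instead of
silently mislabeling tensors in the trainer state.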