diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index ccb4e46d..61a2d25b 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -9,6 +9,7 @@ import enum
 import math
 
 import numpy as np
+import torch
 import torch as th
 from tqdm import tqdm
 
@@ -756,6 +757,7 @@ class GaussianDiffusion:
         terms = {}
 
         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+            x_start_pred = torch.zeros_like(x_start)  # This type of model doesn't predict x_start.
             terms["loss"] = self._vb_terms_bpd(
                 model=model,
                 x_start=x_start,
@@ -791,15 +793,22 @@
                 # Without a factor of 1/1000, the VB term hurts the MSE term.
                 terms["vb"] *= self.num_timesteps / 1000.0
 
-            target = {
-                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+            if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+                target = self.q_posterior_mean_variance(
                     x_start=x_start, x_t=x_t, t=t
-                )[0],
-                ModelMeanType.START_X: x_start,
-                ModelMeanType.EPSILON: noise,
-            }[self.model_mean_type]
+                )[0]
+                x_start_pred = torch.zeros_like(x_start)  # Not supported.
+            elif self.model_mean_type == ModelMeanType.START_X:
+                target = x_start
+                x_start_pred = model_output
+            elif self.model_mean_type == ModelMeanType.EPSILON:
+                target = noise
+                x_start_pred = x_t - model_output  # Rough visualization proxy; see the note below.
+            else:
+                raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
             terms["mse"] = mean_flat((target - model_output) ** 2)
+            terms["x_start_predicted"] = x_start_pred
             if "vb" in terms:
                 terms["loss"] = terms["mse"] + terms["vb"]
             else:
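Note on the EPSILON branch above: `x_t - model_output` is only a cheap proxy for the predicted x_start, adequate for logging but not the exact inversion of the forward process q(x_t | x_0) = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps. A minimal sketch of the exact recovery is below; the helper name is illustrative (upstream improved-diffusion keeps an equivalent `_predict_xstart_from_eps` on GaussianDiffusion with precomputed buffers):

import torch

def predict_xstart_from_eps(x_t, t, eps, alphas_cumprod):
    # Invert q(x_t | x_0): x_0 = sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * eps.
    a_bar = torch.as_tensor(alphas_cumprod, device=x_t.device, dtype=x_t.dtype)[t]
    while a_bar.dim() < x_t.dim():
        a_bar = a_bar.unsqueeze(-1)  # broadcast the per-sample a_bar_t over image dims
    return torch.sqrt(1.0 / a_bar) * x_t - torch.sqrt(1.0 / a_bar - 1.0) * eps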
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.rescale_timesteps, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 2c9f7d35..00636497 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -2,6 +2,7 @@ import torch
 
 from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
 from models.diffusion.resample import create_named_schedule_sampler
+from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.inject import Injector
 from utils.util import opt_get
 
@@ -12,8 +13,11 @@ class GaussianDiffusionInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.generator = opt['generator']
+        self.output_variational_bounds_key = opt['out_key_vb_loss']
+        self.output_x_start_key = opt['out_key_x_start']
         opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
-        self.diffusion = GaussianDiffusion(**opt['diffusion_args'])
+        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], [opt['beta_schedule']['num_diffusion_timesteps']])  # TODO: Figure out how these work and specify them differently.
+        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
         self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
 
@@ -22,7 +26,10 @@ class GaussianDiffusionInjector(Injector):
         hq = state[self.input]
         model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
         t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
-        return {self.output: self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)['loss'] * weights}
+        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
+        return {self.output: diffusion_outputs['mse'],
+                self.output_variational_bounds_key: diffusion_outputs['vb'],
+                self.output_x_start_key: diffusion_outputs['x_start_predicted']}
 
 
 # Performs inference using a network trained to predict a reverse diffusion process, which nets an image.
@@ -32,7 +39,8 @@ class GaussianDiffusionInferenceInjector(Injector):
         self.generator = opt['generator']
         self.output_shape = opt['output_shape']
         opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
-        self.diffusion = GaussianDiffusion(**opt['diffusion_args'])
+        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], [opt['beta_schedule']['num_diffusion_timesteps']])  # TODO: Figure out how these work and specify them differently.
+        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
 
     def forward(self, state):
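Both injectors now build the same thing: a SpacedDiffusion with identity spacing, which reproduces the plain GaussianDiffusion behavior until the TODOs above are resolved. An illustrative sketch, assuming the constructor and enums are unchanged from upstream improved-diffusion (the injectors pull these values out of opt['diffusion_args'] rather than passing enums directly):

from models.diffusion.gaussian_diffusion import (
    ModelMeanType, ModelVarType, LossType, get_named_beta_schedule,
)
from models.diffusion.respace import space_timesteps, SpacedDiffusion

steps = 4000
diffusion = SpacedDiffusion(
    use_timesteps=space_timesteps(steps, [steps]),  # identity spacing, as above
    betas=get_named_beta_schedule('linear', steps),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.LEARNED_RANGE,
    loss_type=LossType.MSE,
)

Passing e.g. space_timesteps(steps, [250]) here instead would train and sample on a 250-step reduction of the same schedule, which is the point of the respacing machinery.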
diff --git a/recipes/ddpm/train_ddpm_rrdb.yml b/recipes/ddpm/train_ddpm_rrdb.yml
new file mode 100644
index 00000000..2416f8f6
--- /dev/null
+++ b/recipes/ddpm/train_ddpm_rrdb.yml
@@ -0,0 +1,115 @@
+#### general settings
+name: train_imgset_rrdb_diffusion
+model: extensibletrainer
+scale: 1
+gpu_ids: [0]
+start_step: -1
+checkpointing_enabled: true
+fp16: false
+use_tb_logger: true
+wandb: false
+
+datasets:
+  train:
+    n_workers: 4
+    batch_size: 32
+    name: div2k
+    mode: single_image_extensible
+    paths: /content/div2k  # <-- Put your path here.
+    target_size: 128
+    force_multiple: 1
+    scale: 4
+    num_corrupts_per_image: 0
+
+networks:
+  generator:
+    type: generator
+    which_model_G: rrdb_diffusion
+    args:
+      in_channels: 6
+      out_channels: 6
+      num_blocks: 10
+
+#### path
+path:
+  #pretrain_model_generator:
+  strict_load: true
+  #resume_state: ../experiments/train_imgset_rrdb_diffusion/training_state/0.state  # <-- Set this to resume from a previous training state.
+
+steps:
+  generator:
+    training: generator
+
+    optimizer_params:
+      lr: !!float 3e-4
+      weight_decay: !!float 1e-2
+      beta1: 0.9
+      beta2: 0.9999
+
+    injectors:
+      # "Do it all injector": produces a reverse prediction and calculates losses on it.
+      diffusion:
+        type: gaussian_diffusion
+        in: hq
+        generator: generator
+        beta_schedule:
+          schedule_name: linear
+          num_diffusion_timesteps: 4000
+        diffusion_args:
+          model_mean_type: epsilon
+          model_var_type: learned_range
+          loss_type: mse
+        sampler_type: uniform
+        model_input_keys:
+          low_res: lq
+        out: loss
+        out_key_vb_loss: vb_loss        # Required by the gaussian_diffusion injector.
+        out_key_x_start: x_start_pred   # Required by the gaussian_diffusion injector.
+
+      # Injector for visualizing what your network is doing (every 500 steps)
+      visual_debug:
+        every: 500
+        type: gaussian_diffusion_inference
+        generator: generator
+        output_shape: [8,3,128,128]  # Change "8" to your desired output batch size.
+        beta_schedule:
+          schedule_name: linear
+          num_diffusion_timesteps: 500  # Change higher (up to training steps) for improved quality. Lower for faster speed.
+        diffusion_args:
+          model_mean_type: epsilon
+          model_var_type: learned_range
+          loss_type: mse
+        model_input_keys:
+          low_res: lq
+        out: sample
+
+    losses:
+      diffusion_loss:
+        type: direct
+        weight: 1
+        key: loss
+      # The injector returns the variational bound separately; it trains the learned_range variance head.
+      diffusion_vb_loss:
+        type: direct
+        weight: 1
+        key: vb_loss
+
+train:
+  niter: 500000
+  warmup_iter: -1
+  mega_batch_factor: 1  # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
+  val_freq: 4000
+
+  # Default LR scheduler options
+  default_lr_scheme: CosineAnnealingLR_Restart
+  T_period: [ 200000, 200000 ]
+  warmup: 0
+  eta_min: !!float 1e-7
+  restarts: [ 200000, 400000 ]
+  restart_weights: [ .5, .5 ]
+
+logger:
+  print_freq: 30
+  save_checkpoint_freq: 2000
+  visuals: [sample, hq, lq]
+  visual_debug_rate: 500
+  reverse_n1_to_1: true
\ No newline at end of file
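For reference, the `linear` schedule requested by this recipe follows the improved-diffusion convention (an assumption about this repo's get_named_beta_schedule): the canonical DDPM beta endpoints are defined at 1000 steps and rescaled for other step counts. A rough equivalent:

import numpy as np

# Assumed equivalent of get_named_beta_schedule('linear', 4000): the DDPM
# endpoints (1e-4, 0.02) are defined for 1000 steps and scaled linearly.
num_diffusion_timesteps = 4000
scale = 1000 / num_diffusion_timesteps
betas = np.linspace(scale * 0.0001, scale * 0.02, num_diffusion_timesteps, dtype=np.float64)

Note that the visual_debug injector builds its own 500-step schedule, so its betas are scaled independently of the 4000-step training schedule; raising its num_diffusion_timesteps trades sampling speed for quality, as the inline comment says.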