GD mods & fixes

- Report variational loss separately
- Report model prediction from injector
- Log these things
- Use respacing like guided diffusion
James Betker 2021-06-04 17:13:16 -06:00
parent 6084915af8
commit bf811f80c1
4 changed files with 263 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ import enum
import math
import numpy as np
import torch
import torch as th
from tqdm import tqdm
@@ -756,6 +757,7 @@ class GaussianDiffusion:
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
    x_start_pred = torch.zeros_like(x_start)  # This type of model doesn't predict x_start.
    terms["loss"] = self._vb_terms_bpd(
        model=model,
        x_start=x_start,
@@ -791,15 +793,22 @@
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
target = {
    ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
        x_start=x_start, x_t=x_t, t=t
    )[0],
    ModelMeanType.START_X: x_start,
    ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
    target = self.q_posterior_mean_variance(
        x_start=x_start, x_t=x_t, t=t
    )[0]
    x_start_pred = torch.zeros_like(x_start)  # Not supported; this mean type doesn't predict x_start.
elif self.model_mean_type == ModelMeanType.START_X:
    target = x_start
    x_start_pred = model_output
elif self.model_mean_type == ModelMeanType.EPSILON:
    target = noise
    x_start_pred = x_t - model_output  # Rough proxy for the predicted x_start.
else:
    raise NotImplementedError(self.model_mean_type)
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
terms["x_start_predicted"] = x_start_pred
if "vb" in terms:
    terms["loss"] = terms["mse"] + terms["vb"]
else:
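For reference, a minimal sketch of how a caller might consume the richer training_losses() output (the diffusion, model, x_start and t names are assumed to come from the surrounding training loop; they are not part of this diff):

terms = diffusion.training_losses(model, x_start, t)
mse_loss = terms["mse"].mean()                           # primary reconstruction objective
vb_loss = terms["vb"].mean() if "vb" in terms else None  # variational bound, now reported separately
x0_pred = terms["x_start_predicted"]                     # the model's implied x_start, useful for logging/visualization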

View File

@@ -0,0 +1,128 @@
import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there are 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
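(Aside, not part of the committed file: a quick sanity check of the docstring example above, using only space_timesteps.)

# 300 steps split into three 100-step sections, strided down to 10, 15 and 20 retained steps.
assert len(space_timesteps(300, [10, 15, 20])) == 45
# "ddimN" requests the fixed DDIM striding; here, 25 evenly spaced steps out of 1000.
assert len(space_timesteps(1000, "ddim25")) == 25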
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
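Putting the module together, a minimal construction sketch that mirrors what the updated injectors below do (module paths follow their imports; the string-valued diffusion_args are copied from the example config at the end of this commit and are assumed to be accepted by the ported GaussianDiffusion exactly as the injector passes them):

from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import SpacedDiffusion, space_timesteps

diffusion_args = {'model_mean_type': 'epsilon', 'model_var_type': 'learned_range', 'loss_type': 'mse'}
diffusion_args['betas'] = get_named_beta_schedule(schedule_name='linear', num_diffusion_timesteps=4000)
# The injectors currently retain every step (see their TODO); pass a smaller section
# count such as [250] to actually skip steps when sampling.
diffusion_args['use_timesteps'] = space_timesteps(4000, [4000])
diffusion = SpacedDiffusion(**diffusion_args)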

View File

@@ -2,6 +2,7 @@ import torch
from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
from models.diffusion.resample import create_named_schedule_sampler
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.inject import Injector
from utils.util import opt_get
@@ -12,8 +13,11 @@ class GaussianDiffusionInjector(Injector):
def __init__(self, opt, env):
    super().__init__(opt, env)
    self.generator = opt['generator']
    self.output_variational_bounds_key = opt['out_key_vb_loss']
    self.output_x_start_key = opt['out_key_x_start']
    opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
    self.diffusion = GaussianDiffusion(**opt['diffusion_args'])
    opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], [opt['beta_schedule']['num_diffusion_timesteps']])  # TODO: Figure out how these work and specify them differently.
    self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
    self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
    self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
@@ -22,7 +26,10 @@ class GaussianDiffusionInjector(Injector):
hq = state[self.input]
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
return {self.output: self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)['loss'] * weights}
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
return {self.output: diffusion_outputs['mse'],
        self.output_variational_bounds_key: diffusion_outputs['vb'],
        self.output_x_start_key: diffusion_outputs['x_start_predicted']}
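Note that the injector now reads two extra opt keys besides out; the example config at the bottom of this commit only maps out: loss, so those keys would need to be supplied as well. A sketch of the mapping in Python form (the vb_loss and x_start_pred names are placeholders, not taken from the commit):

# Keys read by GaussianDiffusionInjector.__init__ above; the values name the state
# entries that forward() populates from the diffusion outputs.
injector_output_keys = {
    'out': 'loss',                      # receives diffusion_outputs['mse']
    'out_key_vb_loss': 'vb_loss',       # receives diffusion_outputs['vb']
    'out_key_x_start': 'x_start_pred',  # receives diffusion_outputs['x_start_predicted']
}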
# Performs inference using a network trained to predict a reverse diffusion process, which nets an image.
@@ -32,7 +39,8 @@ class GaussianDiffusionInferenceInjector(Injector):
self.generator = opt['generator']
self.output_shape = opt['output_shape']
opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
self.diffusion = GaussianDiffusion(**opt['diffusion_args'])
opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], [opt['beta_schedule']['num_diffusion_timesteps']]) # TODO: Figure out how these work and specify them differently.
self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
def forward(self, state):

View File

@@ -0,0 +1,109 @@
#### general settings
name: train_imgset_rrdb_diffusion
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true
fp16: false
use_tb_logger: true
wandb: false
datasets:
  train:
    n_workers: 4
    batch_size: 32
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    num_corrupts_per_image: 0

networks:
  generator:
    type: generator
    which_model_G: rrdb_diffusion
    args:
      in_channels: 6
      out_channels: 6
      num_blocks: 10

#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_imgset_rrdb_diffusion/training_state/0.state # <-- Set this to resume from a previous training state.

steps:
  generator:
    training: generator

    optimizer_params:
      lr: !!float 3e-4
      weight_decay: !!float 1e-2
      beta1: 0.9
      beta2: 0.9999

    injectors:
      # "Do it all injector": produces a reverse prediction and calculates losses on it.
      diffusion:
        type: gaussian_diffusion
        in: hq
        generator: generator
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        sampler_type: uniform
        model_input_keys:
          low_res: lq
        out: loss
      # Injector for visualizing what your network is doing (every 500 steps)
      visual_debug:
        every: 500
        type: gaussian_diffusion_inference
        generator: generator
        output_shape: [8,3,128,128] # Change "8" to your desired output batch size.
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 500 # Change higher (up to training steps) for improved quality. Lower for faster speed.
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        model_input_keys:
          low_res: lq
        out: sample

    losses:
      diffusion_loss:
        type: direct
        weight: 1
        key: loss

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
  val_freq: 4000

  # Default LR scheduler options
  default_lr_scheme: CosineAnnealingLR_Restart
  T_period: [ 200000, 200000 ]
  warmup: 0
  eta_min: !!float 1e-7
  restarts: [ 200000, 400000 ]
  restart_weights: [ .5, .5 ]

logger:
  print_freq: 30
  save_checkpoint_freq: 2000
  visuals: [sample, hq, lq]
  visual_debug_rate: 500
  reverse_n1_to_1: true