forked from mrq/DL-Art-School
GD mods & fixes

- Report variational loss separately
- Report model prediction from injector
- Log these things
- Use respacing like guided diffusion
parent 6084915af8
commit bf811f80c1
@@ -9,6 +9,7 @@ import enum
 import math

 import numpy as np
+import torch
 import torch as th
 from tqdm import tqdm

@@ -756,6 +757,7 @@ class GaussianDiffusion:
         terms = {}

         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+            x_start_pred = torch.zeros_like(x_start)  # This type of model doesn't predict x_start.
             terms["loss"] = self._vb_terms_bpd(
                 model=model,
                 x_start=x_start,
@@ -791,15 +793,22 @@ class GaussianDiffusion:
                     # Without a factor of 1/1000, the VB term hurts the MSE term.
                     terms["vb"] *= self.num_timesteps / 1000.0

-            target = {
-                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
-                    x_start=x_start, x_t=x_t, t=t
-                )[0],
-                ModelMeanType.START_X: x_start,
-                ModelMeanType.EPSILON: noise,
-            }[self.model_mean_type]
+            if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+                target = self.q_posterior_mean_variance(
+                    x_start=x_start, x_t=x_t, t=t
+                )[0]
+                x_start_pred = torch.zeros(x_start)  # Not supported.
+            elif self.model_mean_type == ModelMeanType.START_X:
+                target = x_start
+                x_start_pred = model_output
+            elif self.model_mean_type == ModelMeanType.EPSILON:
+                target = noise
+                x_start_pred = x_t - model_output
+            else:
+                raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
             terms["mse"] = mean_flat((target - model_output) ** 2)
+            terms["x_start_predicted"] = x_start_pred
             if "vb" in terms:
                 terms["loss"] = terms["mse"] + terms["vb"]
             else:
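
Note that the `x_start_pred = x_t - model_output` branch above is only a rough stand-in used for reporting; the exact recovery of x_start from a predicted epsilon also divides out the noise schedule (upstream guided-diffusion keeps a helper for this, `_predict_xstart_from_eps`, which is not shown in this diff). A minimal standalone sketch of that relation, where the function name and the `alphas_cumprod_t` argument (the a_bar_t values already gathered for the batch timesteps) are illustrative assumptions, not code from this repository:

import torch

def predict_xstart_from_eps(x_t, eps, alphas_cumprod_t):
    # DDPM forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps,
    # so the exact inversion is x_0 = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t).
    return (x_t - torch.sqrt(1.0 - alphas_cumprod_t) * eps) / torch.sqrt(alphas_cumprod_t)
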

codes/models/diffusion/respace.py (new file, 128 lines)
import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
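
A small usage sketch of the new module, reusing the docstring's own 300/[10,15,20] example and the "ddim" shorthand; the import path follows the injector change below, and running it from anywhere other than the `codes` package root may require adjusting that path:

from models.diffusion.respace import space_timesteps

# Docstring example: 300 original steps, three equal sections strided down to
# 10, 15 and 20 retained steps respectively, i.e. 45 retained steps in total.
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 45

# DDIM-style fixed striding: keep exactly 25 evenly strided steps out of 1000.
ddim_steps = space_timesteps(1000, "ddim25")
assert len(ddim_steps) == 25

# SpacedDiffusion then takes these steps plus the usual GaussianDiffusion kwargs
# (betas, model_mean_type, ...), e.g. SpacedDiffusion(use_timesteps=steps, **diffusion_args).
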
@@ -2,6 +2,7 @@ import torch

 from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
 from models.diffusion.resample import create_named_schedule_sampler
+from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.inject import Injector
 from utils.util import opt_get

@@ -12,8 +13,11 @@ class GaussianDiffusionInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.generator = opt['generator']
+        self.output_variational_bounds_key = opt['out_key_vb_loss']
+        self.output_x_start_key = opt['out_key_x_start']
         opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
-        self.diffusion = GaussianDiffusion(**opt['diffusion_args'])
+        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], [opt['beta_schedule']['num_diffusion_timesteps']])  # TODO: Figure out how these work and specify them differently.
+        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
         self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])

@@ -22,7 +26,10 @@ class GaussianDiffusionInjector(Injector):
         hq = state[self.input]
         model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
         t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
-        return {self.output: self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)['loss'] * weights}
+        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
+        return {self.output: diffusion_outputs['mse'],
+                self.output_variational_bounds_key: diffusion_outputs['vb'],
+                self.output_x_start_key: diffusion_outputs['x_start_predicted']}


 # Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
@@ -32,7 +39,8 @@ class GaussianDiffusionInferenceInjector(Injector):
         self.generator = opt['generator']
         self.output_shape = opt['output_shape']
         opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
-        self.diffusion = GaussianDiffusion(**opt['diffusion_args'])
+        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'], [opt['beta_schedule']['num_diffusion_timesteps']])  # TODO: Figure out how these work and specify them differently.
+        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])

     def forward(self, state):
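
Because `GaussianDiffusionInjector.__init__` now reads `out_key_vb_loss` and `out_key_x_start` unconditionally, a config that uses this injector needs both keys next to `out`. A sketch of the relevant option block as a plain Python dict, mirroring the recipe added below; the state names 'vb_loss' and 'x_start_pred' are illustrative placeholders, not fixed by the code:

# Keys read by the updated injector; values mirror recipes/ddpm/train_ddpm_rrdb.yml.
diffusion_injector_opt = {
    'type': 'gaussian_diffusion',
    'in': 'hq',
    'generator': 'generator',
    'beta_schedule': {'schedule_name': 'linear', 'num_diffusion_timesteps': 4000},
    'diffusion_args': {'model_mean_type': 'epsilon',
                       'model_var_type': 'learned_range',
                       'loss_type': 'mse'},
    'sampler_type': 'uniform',
    'model_input_keys': {'low_res': 'lq'},
    'out': 'loss',                       # receives diffusion_outputs['mse']
    'out_key_vb_loss': 'vb_loss',        # receives diffusion_outputs['vb']
    'out_key_x_start': 'x_start_pred',   # receives diffusion_outputs['x_start_predicted']
}
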

recipes/ddpm/train_ddpm_rrdb.yml (new file, 109 lines)
#### general settings
name: train_imgset_rrdb_diffusion
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true
fp16: false
use_tb_logger: true
wandb: false

datasets:
  train:
    n_workers: 4
    batch_size: 32
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k   # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    num_corrupts_per_image: 0

networks:
  generator:
    type: generator
    which_model_G: rrdb_diffusion
    args:
      in_channels: 6
      out_channels: 6
      num_blocks: 10

#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_imgset_rrdb_diffusion/training_state/0.state   # <-- Set this to resume from a previous training state.

steps:
  generator:
    training: generator

    optimizer_params:
      lr: !!float 3e-4
      weight_decay: !!float 1e-2
      beta1: 0.9
      beta2: 0.9999

    injectors:
      # "Do it all injector": produces a reverse prediction and calculates losses on it.
      diffusion:
        type: gaussian_diffusion
        in: hq
        generator: generator
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        sampler_type: uniform
        model_input_keys:
          low_res: lq
        out: loss

      # Injector for visualizing what your network is doing (every 500 steps)
      visual_debug:
        every: 500
        type: gaussian_diffusion_inference
        generator: generator
        output_shape: [8,3,128,128]   # Change "8" to your desired output batch size.
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 500   # Change higher (up to training steps) for improved quality. Lower for faster speed.
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        model_input_keys:
          low_res: lq
        out: sample

    losses:
      diffusion_loss:
        type: direct
        weight: 1
        key: loss

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1   # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
  val_freq: 4000

  # Default LR scheduler options
  default_lr_scheme: CosineAnnealingLR_Restart
  T_period: [ 200000, 200000 ]
  warmup: 0
  eta_min: !!float 1e-7
  restarts: [ 200000, 400000 ]
  restart_weights: [ .5, .5 ]

logger:
  print_freq: 30
  save_checkpoint_freq: 2000
  visuals: [sample, hq, lq]
  visual_debug_rate: 500
  reverse_n1_to_1: true