diffuse the cascaded prior for continuous sr model

James Betker 2022-07-20 11:54:09 -06:00
parent b0e3be0a17
commit 82bd62019f
2 changed files with 55 additions and 24 deletions


@@ -1,10 +1,10 @@
 import itertools
+import random
 from random import randrange
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision.utils
 from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@@ -88,7 +88,7 @@ class ConditioningEncoder(nn.Module):
     def forward(self, x, resolution):
         h = self.init(x) + self.resolution_embedding(resolution).unsqueeze(-1)
         h = self.attn(h)
-        return h[:, :, :6]
+        return h[:, :, :5]


 class TransformerDiffusion(nn.Module):
@@ -130,10 +130,14 @@ class TransformerDiffusion(nn.Module):
             nn.SiLU(),
             linear(time_embed_dim, model_channels),
         )
+        self.prior_time_embed = nn.Sequential(
+            linear(time_embed_dim, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, model_channels),
+        )
         self.resolution_embed = nn.Embedding(resolution_steps, model_channels)
         self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, resolution_steps, num_attn_heads=model_channels//64)
-        self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,6))
-        self.unconditioned_prior = nn.Parameter(torch.zeros(1,in_channels,1))
+        self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,5))
         self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
         self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
@@ -169,7 +173,7 @@ class TransformerDiffusion(nn.Module):
         }
         return groups

-    def input_to_random_resolution_and_window(self, x):
+    def input_to_random_resolution_and_window(self, x, ts, diffuser):
         """
         This function MUST be applied to the target *before* noising. It returns the reduced, re-scoped target as well
         as caches an internal prior for the rescoped target which will be used in training.
@@ -185,26 +189,47 @@ class TransformerDiffusion(nn.Module):
         s = s[:,:,start:start+self.max_window]
         s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
         s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
-        self.preprocessed = (s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
+        # Now diffuse the prior randomly between the x timestep and 0.
+        adv = torch.rand_like(ts.float())
+        t_prior = (adv * ts).long()
+        s_prior_diffused = diffuser.q_sample(s_prior, t_prior, torch.randn_like(s_prior))
+        self.preprocessed = (s_prior_diffused, t_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
         return s

-    def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
+    def forward(self, x, timesteps, prior_timesteps=None, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
+        """
+        Predicts the previous diffusion timestep of x, given a partially diffused low-resolution prior and a conditioning
+        input.
+
+        All parameters are optional because during training, input_to_random_resolution_and_window is used by a training
+        harness to preformat the inputs and fill in the parameters as state variables.
+
+        Args:
+            x: Prediction prior.
+            timesteps: Number of timesteps x has been diffused for.
+            prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element.
+            x_prior: A low-resolution prior that guides the model.
+            resolution: Integer indicating the operating resolution level. '0' is the highest resolution.
+            conditioning_input: A semi-related (un-aligned) conditioning input which is used to guide diffusion. Similar to a class input, but hooked to a learned conditioning encoder.
+            conditioning_free: Whether or not to ignore the conditioning input.
+        """
         conditioning_input = x_prior if conditioning_input is None else conditioning_input
-        h = x
         if resolution is None:
             # This is assumed to be training.
             assert self.preprocessed is not None, 'Preprocessing function not called.'
             assert x_prior is None, 'Provided prior will not be used, instead preprocessing output will be used.'
-            h_sub, resolution = self.preprocessed
+            x_prior, prior_timesteps, resolution = self.preprocessed
             self.preprocessed = None
         else:
-            assert h.shape[-1] > x_prior.shape[-1] * 3.9, f'{h.shape} {x_prior.shape}'
-            h_sub = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
+            assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}'
+            x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
+            assert torch.all(timesteps - prior_timesteps > 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'

         if conditioning_free:
-            h_sub = self.unconditioned_prior.repeat(x.shape[0], 1, x.shape[-1])
-            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
+            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
         else:
             MIN_COND_LEN = 200
             MAX_COND_LEN = 1200
@@ -217,17 +242,17 @@ class TransformerDiffusion(nn.Module):
             # Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance.
             if self.training and self.unconditioned_percentage > 0:
-                unconditioned_batches = torch.rand((h.shape[0], 1, 1),
-                                                   device=h.device) < self.unconditioned_percentage
-                h_sub = torch.where(unconditioned_batches, self.unconditioned_prior.repeat(h_sub.shape[0], 1, h_sub.shape[-1]), h_sub)
+                unconditioned_batches = torch.rand((x.shape[0], 1, 1),
+                                                   device=x.device) < self.unconditioned_percentage
                 code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1), code_emb)

         with torch.autocast(x.device.type, enabled=self.enable_fp16):
             time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
+            prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim))
             res_emb = self.resolution_embed(resolution)
-            blk_emb = torch.cat([time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
-            h = torch.cat([h, h_sub], dim=1)
+            blk_emb = torch.cat([time_emb.unsqueeze(-1), prior_time_emb.unsqueeze(-1), res_emb.unsqueeze(-1), code_emb], dim=-1)
+            h = torch.cat([x, x_prior], dim=1)
             h = self.inp_block(h)
             for layer in self.layers:
                 h = checkpoint(layer, h, blk_emb)
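A note on the shape bookkeeping in the hunk above: `blk_emb` is a fixed-width stack of embedding tokens along the last axis, which is presumably why `ConditioningEncoder.forward` and `unconditioned_embedding` shrink from 6 to 5 tokens elsewhere in this commit; with the new `prior_time_emb` token the total width stays the same. A small illustrative sketch (B and C are an assumed batch size and `model_channels`):

```python
import torch

B, C = 2, 512                        # assumed batch size / model_channels
time_emb = torch.randn(B, C)
prior_time_emb = torch.randn(B, C)   # new in this commit
res_emb = torch.randn(B, C)
code_emb = torch.randn(B, C, 5)      # ConditioningEncoder now returns 5 tokens (was 6)

blk_emb = torch.cat([time_emb.unsqueeze(-1), prior_time_emb.unsqueeze(-1),
                     res_emb.unsqueeze(-1), code_emb], dim=-1)
assert blk_emb.shape == (B, C, 8)    # 1 + 1 + 1 + 5 == the old 1 + 1 + 6
```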
@@ -236,7 +261,7 @@ class TransformerDiffusion(nn.Module):
         out = self.out(h)

         # Defensively involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
-        unused_params = [self.unconditioned_prior, self.unconditioned_embedding]
+        unused_params = [self.unconditioned_embedding]
         extraneous_addition = 0
         for p in unused_params:
             extraneous_addition = extraneous_addition + p.mean()
@@ -251,6 +276,12 @@ def register_transformer_diffusion13(opt_net, opt):


 def test_tfd():
+    from models.diffusion.respace import SpacedDiffusion
+    from models.diffusion.respace import space_timesteps
+    from models.diffusion.gaussian_diffusion import get_named_beta_schedule
+    diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon',
+                               model_var_type='learned_range', loss_type='mse',
+                               betas=get_named_beta_schedule('linear', 4000))
     clip = torch.randn(2,256,10336)
     cond = torch.randn(2,256,10336)
     ts = torch.LongTensor([600, 600])
@@ -258,8 +289,8 @@ def test_tfd():
                                  num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
                                  unconditioned_percentage=.6)
     for k in range(100):
-        x = model.input_to_random_resolution_and_window(clip, x_prior=clip)
-        model(x, ts, clip)
+        x = model.input_to_random_resolution_and_window(clip, ts, diffuser)
+        model(x, ts, conditioning_input=cond)


 def remove_conditioning(sd_path):
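The updated test only exercises the training path, where `input_to_random_resolution_and_window` fills in the prior, prior timesteps, and resolution as state. For the explicit-resolution path, a hedged usage sketch of the new signature is below; `model` is assumed to be the `TransformerDiffusion` instance built in `test_tfd`, and the tensor sizes are illustrative:

```python
import torch

x = torch.randn(2, 256, 1024)             # current-scale input being denoised
x_prior = torch.randn(2, 256, 256)        # low-res prior; > 3.9x shorter, per the assert
ts = torch.LongTensor([600, 600])         # diffusion steps applied to x
prior_ts = torch.LongTensor([150, 150])   # must be strictly less than ts element-wise
res = torch.LongTensor([1, 1])            # operating resolution level ('0' is highest)
cond = torch.randn(2, 256, 2048)          # un-aligned conditioning input

out = model(x, ts, prior_timesteps=prior_ts, x_prior=x_prior,
            resolution=res, conditioning_input=cond)
```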


@@ -89,10 +89,10 @@ class GaussianDiffusionInjector(Injector):
             sampler = self.schedule_sampler
             self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
         model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
-        if self.preprocess_fn is not None:
-            hq = getattr(gen.module, self.preprocess_fn)(hq)
         t, weights = sampler.sample(hq.shape[0], hq.device)
+        if self.preprocess_fn is not None:
+            hq = getattr(gen.module, self.preprocess_fn)(hq, t, self.diffusion)
         if self.causal_mode:
             cs, ce = self.causal_slope_range
             slope = random.random() * (ce-cs) + cs
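The injector-side change is an ordering one: timesteps must now be drawn before the preprocessing hook runs, because the hook (here `input_to_random_resolution_and_window`) needs `t` and the diffusion object to partially diffuse its cached prior. Condensed from the hunk above, with the rest of `GaussianDiffusionInjector.forward` elided:

```python
# Draw the per-batch timesteps first...
t, weights = sampler.sample(hq.shape[0], hq.device)
# ...so the model's preprocessing hook can noise its internal prior to some step below t.
if self.preprocess_fn is not None:
    hq = getattr(gen.module, self.preprocess_fn)(hq, t, self.diffusion)
```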