Support causal diffusion!

James Betker 2022-07-08 12:30:05 -06:00
parent 78bba690de
commit 7b4dcbf136
6 changed files with 108 additions and 53 deletions

View File

@@ -81,11 +81,14 @@ class ConditioningEncoder(nn.Module):
                  attn_blocks=6,
                  num_attn_heads=8,
                  dropout=.1,
-                 do_checkpointing=False):
+                 do_checkpointing=False,
+                 time_proj=True):
         super().__init__()
         attn = []
         self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
-        self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
+        self.time_proj = time_proj
+        if time_proj:
+            self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
         self.attn = Encoder(
             dim=embedding_dim,
             depth=attn_blocks,
@@ -103,8 +106,9 @@ class ConditioningEncoder(nn.Module):
     def forward(self, x, time_emb):
         h = self.init(x).permute(0,2,1)
-        time_enc = self.time_proj(time_emb)
-        h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
+        if self.time_proj:
+            time_enc = self.time_proj(time_emb)
+            h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
         h = self.attn(h).permute(0,2,1)
         return h
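The time token is now optional. A minimal sketch of the effect, assuming the constructor signature from the call site in a later hunk (other arguments left at their defaults): with time_proj=True one extra token is prepended, so the encoder output is one frame longer than the input conditioning.

enc_t = ConditioningEncoder(256, 512, 512, time_proj=True)
enc_nt = ConditioningEncoder(256, 512, 512, time_proj=False)
cond = torch.randn(2, 256, 100)  # (batch, cond_dim, frames)
temb = torch.randn(2, 512)       # (batch, time_embed_dim)
print(enc_t(cond, temb).shape)   # torch.Size([2, 512, 101]) - time token prepended
print(enc_nt(cond, temb).shape)  # torch.Size([2, 512, 100]) - time_emb ignored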
@@ -125,6 +129,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  input_cond_dim=1024,
                  num_heads=8,
                  dropout=0,
+                 time_proj=False,
                  use_fp16=False,
                  checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
                  # Parameters for regularization.
@@ -141,7 +146,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.enable_fp16 = use_fp16
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
-        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)
+        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)

         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
@@ -210,11 +215,12 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         cond_left = cond_aligned[:,:,:break_pt]
         cond_right = cond_aligned[:,:,break_pt:]
-        # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
-        to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
-        cond_left = cond_left[:,:,:-to_remove_left]
-        to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
-        cond_right = cond_right[:,:,to_remove_right:]
+        if self.training:
+            # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
+            to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
+            cond_left = cond_left[:,:,:-to_remove_left]
+            to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
+            cond_right = cond_right[:,:,to_remove_right:]
         # Concatenate the _pre and _post back on.
         cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
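Gating the dropout on self.training means inference now sees the full aligned conditioning. A standalone sketch of the trimming logic itself; MIN_MARGIN is an assumed value here (the real constant is defined elsewhere in this file):

import random
import torch

MIN_MARGIN = 8  # assumed value for illustration
cond_left = torch.randn(1, 256, 64)
cond_right = torch.randn(1, 256, 64)
# Trim 1..(len-MIN_MARGIN) frames from the edges facing the gap, so at least
# MIN_MARGIN frames always survive on each side for the network to anchor on.
to_remove_left = random.randint(1, cond_left.shape[-1] - MIN_MARGIN)
cond_left = cond_left[:, :, :-to_remove_left]
to_remove_right = random.randint(1, cond_right.shape[-1] - MIN_MARGIN)
cond_right = cond_right[:, :, to_remove_right:]
print(cond_left.shape[-1], cond_right.shape[-1])  # each in [MIN_MARGIN, 63]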

View File

@@ -7,14 +7,15 @@ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
 import enum
 import math
+import random
 import numpy as np
 import torch
 import torch as th
 from tqdm import tqdm

-from .nn import mean_flat
-from .losses import normal_kl, discretized_gaussian_log_likelihood
+from models.diffusion.nn import mean_flat
+from models.diffusion.losses import normal_kl, discretized_gaussian_log_likelihood


 def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
@@ -37,43 +38,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
     # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
     # actually work, we map the existing t from the timescale specified to the model to the causal timescale:
     adj_t = t * (num_timesteps + S_sloped) // num_timesteps
+    adj_t = adj_t - S_sloped
     if add_jitter:
-        jitter = (random.random() - .5) * S_sloped
-        adj_t = (adj_t+jitter).clamp(0, num_timesteps+S_sloped)
+        t_gap = (num_timesteps + S_sloped) / num_timesteps
+        jitter = (2*random.random()-1) * t_gap
+        adj_t = (adj_t+jitter).clamp(-S_sloped, num_timesteps)

     # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
     t = adj_t.unsqueeze(1).repeat(1, S)
-    t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()
+    t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long()
     return t
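A worked example of the mapping, assuming S_sloped = S * causal_slope (consistent with how it is used above) and this module's torch import:

# t=500, S=8, num_timesteps=1000, causal_slope=100:
#   S_sloped = 8 * 100 = 800
#   adj_t    = 500 * (1000 + 800) // 1000 - 800 = 100
# Position i then receives clamp(100 + 100*i, 0, 1000):
t = causal_timestep_adjustment(torch.tensor([500]), 8, 1000, causal_slope=100, add_jitter=False)
print(t)  # tensor([[100, 200, 300, 400, 500, 600, 700, 800]])
# Earlier positions carry smaller timesteps (less noise), so denoising
# finishes left-to-right across the sequence - the "causal" behavior.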
-def graph_causal_timestep_adjustment():
-    S = 400
-    slope=4
-    num_timesteps=4000
-    #for num_timesteps in range(100, 4000, 200):
-    t_res = []
-    for t in range(num_timesteps, -1, -num_timesteps//50):
-        T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
-        t_res.append(T)
-        plt.plot(T.numpy())
-        plt.ylim(0,4000)
-        plt.xlim(0,500)
-        plt.savefig(f'{t}.png')
-        plt.clf()
-    for i in range(len(t_res)):
-        for j in range(len(t_res)):
-            if i == j:
-                continue
-            #assert not torch.all(t_res[i] == t_res[j])
-    plt.ylim(0,4000)
-    plt.xlim(0,500)
-    plt.ylabel('timestep')
-    plt.savefig(f'{num_timesteps}.png')
-    plt.clf()

 def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
     """
     Get a pre-defined beta schedule for the given name.
@@ -319,7 +295,7 @@ class GaussianDiffusion:
             model_kwargs = {}

         B, C = x.shape[:2]
-        assert t.shape == (B,)
+        assert t.shape == (B,) or t.shape == (B,1,x.shape[-1])
         model_output = model(x, self._scale_timesteps(t), **model_kwargs)
         if self.conditioning_free:
             model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
@@ -844,7 +820,10 @@ class GaussianDiffusion:
         # At the first timestep return the decoder NLL,
         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
-        output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
+        if len(t.shape) == 1:
+            output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
+        else:
+            output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}
     def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None):
@@ -852,8 +831,9 @@ class GaussianDiffusion:
         Compute training losses for a causal diffusion process.
         """
         assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
-        t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
-        return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn)
+        ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
+        ct = ct.unsqueeze(1)  # Necessary to make the output shape compatible with x_start.
+        return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)
     def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
         """
@@ -872,6 +852,13 @@ class GaussianDiffusion:
             model_kwargs = {}
         if noise is None:
             noise = th.randn_like(x_start)
+        if len(t.shape) == 3:
+            t_mask = t != self.num_timesteps
+            t[t_mask.logical_not()] = self.num_timesteps-1
+        else:
+            t_mask = torch.ones_like(x_start)
         x_t = self.q_sample(x_start, t, noise=noise)

         terms = {}
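In the causal case t arrives as (B, 1, S) and can contain the sentinel value num_timesteps for positions that are still pure noise; those have no entry in the beta schedule, so they are remapped to a valid index and their loss contribution is zeroed through t_mask further down. A minimal sketch of the same idea:

import torch

num_timesteps = 1000
t = torch.tensor([[[998, 1000, 350, 1000]]])  # (B=1, 1, S=4); 1000 is the sentinel
t_mask = t != num_timesteps                   # False where the position is pure noise
t[t_mask.logical_not()] = num_timesteps - 1   # remap so schedule indexing stays in range
loss = torch.rand(1, 1, 4) * t_mask           # masked positions contribute zero loss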
@@ -912,7 +899,7 @@ class GaussianDiffusion:
                     x_t=x_t,
                     t=t,
                     clip_denoised=False,
-                )["output"]
+                )["output"] * t_mask
                 if self.loss_type == LossType.RESCALED_MSE:
                     # Divide by 1000 for equivalence with initial implementation.
                     # Without a factor of 1/1000, the VB term hurts the MSE term.
@@ -932,7 +919,7 @@ class GaussianDiffusion:
             else:
                 raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
-            s_err = (target - model_output) ** 2
+            s_err = t_mask * (target - model_output) ** 2
             if channel_balancing_fn is not None:
                 s_err = channel_balancing_fn(s_err)
             terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
@@ -1039,3 +1026,47 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
+
+
+def test_causal_training_losses():
+    from models.diffusion.respace import SpacedDiffusion
+    from models.diffusion.respace import space_timesteps
+    diff = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon',
+                           model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
+                           conditioning_free=False, conditioning_free_k=1)
+    class IdentityTwoArg(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x, *args, **kwargs):
+            return x.repeat(1,2,1)
+    model = IdentityTwoArg()
+    diff.causal_training_losses(model, torch.randn(4,256,400), torch.tensor([500,1000,3000,3500]), causal_slope=4)
+
+
+def graph_causal_timestep_adjustment():
+    import matplotlib.pyplot as plt
+    S = 400
+    #slope=4
+    num_timesteps=4000
+    for slpe in range(10, 400, 10):
+        slope = slpe / 10
+        t_res = []
+        for t in range(num_timesteps, -1, -num_timesteps//50):
+            T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
+            t_res.append(T)
+            plt.plot(T.numpy())
+        for i in range(len(t_res)):
+            for j in range(len(t_res)):
+                if i == j:
+                    continue
+                #assert not torch.all(t_res[i] == t_res[j])
+        plt.ylim(0,num_timesteps)
+        plt.xlim(0,4000)
+        plt.ylabel('timestep')
+        plt.savefig(f'{slpe}.png')
+        plt.clf()
+
+
+if __name__ == '__main__':
+    #test_causal_training_losses()
+    graph_causal_timestep_adjustment()

View File

@@ -122,7 +122,10 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     freqs = th.exp(
         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
     ).to(device=timesteps.device)
-    args = timesteps[:, None].float() * freqs[None]
+    if len(timesteps.shape) == 1:
+        args = timesteps[:, None].float() * freqs[None]
+    else:
+        args = (timesteps.float() * freqs.view(1,half,1)).permute(0,2,1)
     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
     if dim % 2:
         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
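The new branch lets timestep_embedding accept a (B, 1, S) grid of per-position timesteps in addition to the usual (B,) vector. A shape check under the assumption that dim is even:

import math
import torch as th

dim, half, B, S = 128, 64, 2, 400
freqs = th.exp(-math.log(10000) * th.arange(0, half, dtype=th.float32) / half)

t_flat = th.randint(0, 1000, (B,))
args = t_flat[:, None].float() * freqs[None]                         # (B, half)

t_grid = th.randint(0, 1000, (B, 1, S))
args_g = (t_grid.float() * freqs.view(1, half, 1)).permute(0, 2, 1)  # (B, S, half)

# cat(cos, sin) along the last axis then yields (B, dim) and (B, S, dim) embeddings.
print(args.shape, args_g.shape)  # torch.Size([2, 64]) torch.Size([2, 400, 64])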

View File

@@ -365,8 +365,11 @@ class RMSScaleShiftNorm(nn.Module):
         norm = x / norm.clamp(min=self.eps) * self.g
         ss_emb = self.scale_shift_process(norm_scale_shift_inp)
-        scale, shift = torch.chunk(ss_emb, 2, dim=1)
-        h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        scale, shift = torch.chunk(ss_emb, 2, dim=-1)
+        if len(scale.shape) == 2:
+            scale = scale.unsqueeze(1)
+            shift = shift.unsqueeze(1)
+        h = norm * (1 + scale) + shift
         return h
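Chunking on dim=-1 plus the conditional unsqueeze lets the norm take either a global (B, 2D) scale/shift embedding, as before, or a per-position (B, S, 2D) one. A sketch of both broadcast paths, with made-up sizes:

import torch

B, S, D = 2, 10, 64
norm = torch.randn(B, S, D)

ss_global = torch.randn(B, 2 * D)                  # one scale/shift per sequence
scale, shift = torch.chunk(ss_global, 2, dim=-1)
if len(scale.shape) == 2:
    scale, shift = scale.unsqueeze(1), shift.unsqueeze(1)
h_global = norm * (1 + scale) + shift              # (B, S, D)

ss_token = torch.randn(B, S, 2 * D)                # new: per-position modulation
scale, shift = torch.chunk(ss_token, 2, dim=-1)    # already (B, S, D); no unsqueeze
h_token = norm * (1 + scale) + shift               # (B, S, D)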

View File

@@ -339,7 +339,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_tfd12_finetune_ar_outputs.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_cheater_gen.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

View File

@@ -1,4 +1,5 @@
 import functools
+import random
 import torch
 from torch.cuda.amp import autocast
@@ -44,6 +45,8 @@ class GaussianDiffusionInjector(Injector):
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
         self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
+        self.causal_mode = opt_get(opt, ['causal_mode'], False)
+        self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1,8])

         k = 0
         if 'channel_balancer_proportion' in opt.keys():
@@ -86,7 +89,16 @@ class GaussianDiffusionInjector(Injector):
             self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
         model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
         t, weights = sampler.sample(hq.shape[0], hq.device)
-        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
+        if self.causal_mode:
+            cs, ce = self.causal_slope_range
+            slope = random.random() * (ce-cs) + cs
+            diffusion_outputs = self.diffusion.causal_training_losses(gen, hq, t, model_kwargs=model_inputs,
+                                                                      channel_balancing_fn=self.channel_balancing_fn,
+                                                                      causal_slope=slope)
+        else:
+            diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs,
+                                                               channel_balancing_fn=self.channel_balancing_fn)

         if isinstance(sampler, LossAwareSampler):
             sampler.update_with_local_losses(t, diffusion_outputs['loss'])
         if len(self.extra_model_output_keys) > 0:
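To switch a training configuration onto the new path, only the two keys read by the opt_get calls above are needed; the surrounding option structure here is assumed for illustration:

injector_opt = {
    # ... existing GaussianDiffusionInjector options ...
    'causal_mode': True,           # route through causal_training_losses
    'causal_slope_range': [1, 8],  # slope drawn uniformly from this range each step
}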