Support causal diffusion!

James Betker 2022-07-08 12:30:05 -06:00
parent 78bba690de
commit 7b4dcbf136
6 changed files with 108 additions and 53 deletions

View File

@@ -81,11 +81,14 @@ class ConditioningEncoder(nn.Module):
                  attn_blocks=6,
                  num_attn_heads=8,
                  dropout=.1,
-                 do_checkpointing=False):
+                 do_checkpointing=False,
+                 time_proj=True):
         super().__init__()
         attn = []
         self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
-        self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
+        self.time_proj = time_proj
+        if time_proj:
+            self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
         self.attn = Encoder(
             dim=embedding_dim,
             depth=attn_blocks,
@@ -103,8 +106,9 @@ class ConditioningEncoder(nn.Module):
     def forward(self, x, time_emb):
         h = self.init(x).permute(0,2,1)
-        time_enc = self.time_proj(time_emb)
-        h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
+        if self.time_proj:
+            time_enc = self.time_proj(time_emb)
+            h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
         h = self.attn(h).permute(0,2,1)
         return h
@@ -125,6 +129,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  input_cond_dim=1024,
                  num_heads=8,
                  dropout=0,
+                 time_proj=False,
                  use_fp16=False,
                  checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
                  # Parameters for regularization.
@@ -141,7 +146,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.enable_fp16 = use_fp16
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)

-        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)
+        self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)

         self.time_embed = nn.Sequential(
             linear(time_embed_dim, time_embed_dim),
@@ -210,11 +215,12 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         cond_left = cond_aligned[:,:,:break_pt]
         cond_right = cond_aligned[:,:,break_pt:]

-        # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
-        to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
-        cond_left = cond_left[:,:,:-to_remove_left]
-        to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
-        cond_right = cond_right[:,:,to_remove_right:]
+        if self.training:
+            # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
+            to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
+            cond_left = cond_left[:,:,:-to_remove_left]
+            to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
+            cond_right = cond_right[:,:,to_remove_right:]

         # Concatenate the _pre and _post back on.
         cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
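For intuition, here is a minimal standalone sketch of what the new `if self.training:` guard changes: during training a random slice on each side of the break point is discarded, forcing the network to reconstruct conditioning it can no longer see, while at eval time the full aligned conditioning is kept. Tensor sizes and the `MIN_MARGIN` value below are illustrative assumptions, not values taken from the repo.

```python
import random
import torch

MIN_MARGIN = 8                              # assumed value, for illustration only
cond_aligned = torch.randn(2, 256, 100)     # (batch, channels, frames)
break_pt = 50
cond_left = cond_aligned[:, :, :break_pt]
cond_right = cond_aligned[:, :, break_pt:]

training = True
if training:
    # Drop a random chunk adjacent to the break point on each side; the model
    # must infer the missing conditioning from what remains.
    cond_left = cond_left[:, :, :-random.randint(1, cond_left.shape[-1] - MIN_MARGIN)]
    cond_right = cond_right[:, :, random.randint(1, cond_right.shape[-1] - MIN_MARGIN):]
print(cond_left.shape, cond_right.shape)    # both strictly shorter than before
```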

View File

@@ -7,14 +7,15 @@ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
 import enum
 import math
+import random

 import numpy as np
 import torch
 import torch as th
 from tqdm import tqdm

-from .nn import mean_flat
-from .losses import normal_kl, discretized_gaussian_log_likelihood
+from models.diffusion.nn import mean_flat
+from models.diffusion.losses import normal_kl, discretized_gaussian_log_likelihood


 def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
@@ -37,43 +38,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
     # This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
     # actually work, we map the existing t from the timescale specified to the model to the causal timescale:
     adj_t = t * (num_timesteps + S_sloped) // num_timesteps
+    adj_t = adj_t - S_sloped

     if add_jitter:
-        jitter = (random.random() - .5) * S_sloped
-        adj_t = (adj_t+jitter).clamp(0, num_timesteps+S_sloped)
+        t_gap = (num_timesteps + S_sloped) / num_timesteps
+        jitter = (2*random.random()-1) * t_gap
+        adj_t = (adj_t+jitter).clamp(-S_sloped, num_timesteps)

     # Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
     t = adj_t.unsqueeze(1).repeat(1, S)
-    t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()
+    t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long()
     return t

-
-def graph_causal_timestep_adjustment():
-    S = 400
-    slope=4
-    num_timesteps=4000
-    #for num_timesteps in range(100, 4000, 200):
-    t_res = []
-    for t in range(num_timesteps, -1, -num_timesteps//50):
-        T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
-        t_res.append(T)
-        plt.plot(T.numpy())
-        plt.ylim(0,4000)
-        plt.xlim(0,500)
-        plt.savefig(f'{t}.png')
-        plt.clf()
-    for i in range(len(t_res)):
-        for j in range(len(t_res)):
-            if i == j:
-                continue
-            #assert not torch.all(t_res[i] == t_res[j])
-    plt.ylim(0,4000)
-    plt.xlim(0,500)
-    plt.ylabel('timestep')
-    plt.savefig(f'{num_timesteps}.png')
-    plt.clf()
-

 def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
     """
     Get a pre-defined beta schedule for the given name.
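For intuition, here is a standalone sketch of the schedule this rewritten function now produces: every sequence position gets its own timestep, offset from a shared base by `position * causal_slope`, so earlier positions are always less noised than later ones. The new `- S_sloped` shift lets the base start below zero, meaning the left edge of the sequence can finish denoising while the right edge is still pure noise. This mirrors the committed code for illustration rather than importing it.

```python
import random
import torch

def causal_ramp(t, S, num_timesteps, causal_slope=1, add_jitter=False):
    # Extra timesteps needed for the ramp to traverse the whole sequence.
    S_sloped = S * causal_slope
    # Map t onto the extended causal scale, shifted so the ramp can start below 0.
    adj_t = t * (num_timesteps + S_sloped) // num_timesteps - S_sloped
    if add_jitter:
        t_gap = (num_timesteps + S_sloped) / num_timesteps
        adj_t = (adj_t + (2 * random.random() - 1) * t_gap).clamp(-S_sloped, num_timesteps)
    # Position s receives timestep adj_t + s*slope, clamped into the valid range.
    t = adj_t.unsqueeze(1).repeat(1, S)
    return (t + torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()

ramp = causal_ramp(torch.tensor([2000]), S=400, num_timesteps=4000, causal_slope=4)
print(ramp[0, :5])   # earlier positions sit at lower (less noisy) timesteps
print(ramp[0, -5:])  # later positions sit at higher (noisier) timesteps
```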
@@ -319,7 +295,7 @@ class GaussianDiffusion:
             model_kwargs = {}

         B, C = x.shape[:2]
-        assert t.shape == (B,)
+        assert t.shape == (B,) or t.shape == (B,1,x.shape[-1])
         model_output = model(x, self._scale_timesteps(t), **model_kwargs)
         if self.conditioning_free:
             model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
@@ -844,7 +820,10 @@ class GaussianDiffusion:
         # At the first timestep return the decoder NLL,
         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
-        output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
+        if len(t.shape) == 1:
+            output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
+        else:
+            output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}

     def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None):
@@ -852,8 +831,9 @@ class GaussianDiffusion:
         Compute training losses for a causal diffusion process.
         """
         assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
-        t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
-        return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn)
+        ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
+        ct = ct.unsqueeze(1)  # Necessary to make the output shape compatible with x_start.
+        return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)

     def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
         """
@@ -872,6 +852,13 @@ class GaussianDiffusion:
             model_kwargs = {}
         if noise is None:
             noise = th.randn_like(x_start)
+
+        if len(t.shape) == 3:
+            t_mask = t != self.num_timesteps
+            t[t_mask.logical_not()] = self.num_timesteps-1
+        else:
+            t_mask = torch.ones_like(x_start)
+
         x_t = self.q_sample(x_start, t, noise=noise)

         terms = {}
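A note on why this mask exists: the causal ramp clamps to `num_timesteps` inclusive, but a schedule with `num_timesteps` entries only has valid indices up to `num_timesteps - 1`. Positions that saturate at `num_timesteps` are still pure noise and carry no learnable target, so the block above clamps them to the last valid index (purely so `q_sample` can index the schedule arrays) and then removes their contribution to the loss through `t_mask`. A tiny illustration with made-up values:

```python
import torch

num_timesteps = 4000
t = torch.tensor([[[3996, 3998, 4000, 4000]]])  # (B, 1, S) causal timesteps
t_mask = t != num_timesteps                     # False where the position is pure noise
t[t_mask.logical_not()] = num_timesteps - 1     # keep schedule indexing in range
print(t)       # tensor([[[3996, 3998, 3999, 3999]]])
print(t_mask)  # tensor([[[ True,  True, False, False]]])
```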
@@ -912,7 +899,7 @@ class GaussianDiffusion:
                     x_t=x_t,
                     t=t,
                     clip_denoised=False,
-                )["output"]
+                )["output"] * t_mask
                 if self.loss_type == LossType.RESCALED_MSE:
                     # Divide by 1000 for equivalence with initial implementation.
                     # Without a factor of 1/1000, the VB term hurts the MSE term.
@@ -932,7 +919,7 @@ class GaussianDiffusion:
             else:
                 raise NotImplementedError(self.model_mean_type)
             assert model_output.shape == target.shape == x_start.shape
-            s_err = (target - model_output) ** 2
+            s_err = t_mask * (target - model_output) ** 2
             if channel_balancing_fn is not None:
                 s_err = channel_balancing_fn(s_err)
             terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
@@ -1039,3 +1026,47 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
+
+
+def test_causal_training_losses():
+    from models.diffusion.respace import SpacedDiffusion
+    from models.diffusion.respace import space_timesteps
+    diff = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon',
+                           model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
+                           conditioning_free=False, conditioning_free_k=1)
+    class IdentityTwoArg(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x, *args, **kwargs):
+            return x.repeat(1,2,1)
+    model = IdentityTwoArg()
+    diff.causal_training_losses(model, torch.randn(4,256,400), torch.tensor([500,1000,3000,3500]), causal_slope=4)
+
+
+def graph_causal_timestep_adjustment():
+    import matplotlib.pyplot as plt
+    S = 400
+    #slope=4
+    num_timesteps=4000
+    for slpe in range(10, 400, 10):
+        slope = slpe / 10
+        t_res = []
+        for t in range(num_timesteps, -1, -num_timesteps//50):
+            T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
+            t_res.append(T)
+            plt.plot(T.numpy())
+        for i in range(len(t_res)):
+            for j in range(len(t_res)):
+                if i == j:
+                    continue
+                #assert not torch.all(t_res[i] == t_res[j])
+        plt.ylim(0,num_timesteps)
+        plt.xlim(0,4000)
+        plt.ylabel('timestep')
+        plt.savefig(f'{slpe}.png')
+        plt.clf()
+
+
+if __name__ == '__main__':
+    #test_causal_training_losses()
+    graph_causal_timestep_adjustment()

View File

@@ -122,7 +122,10 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     freqs = th.exp(
         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
     ).to(device=timesteps.device)
-    args = timesteps[:, None].float() * freqs[None]
+    if len(timesteps.shape) == 1:
+        args = timesteps[:, None].float() * freqs[None]
+    else:
+        args = (timesteps.float() * freqs.view(1,half,1)).permute(0,2,1)
     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
     if dim % 2:
         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
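A shape walkthrough of the new branch, as a standalone sketch: a per-position timestep map of shape `(B, 1, S)` broadcasts against the frequency vector to yield one sinusoidal embedding per sequence position, `(B, S, dim)`, instead of the usual `(B, dim)`:

```python
import math
import torch as th

dim, max_period = 512, 10000
half = dim // 2
freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half)

timesteps = th.randint(0, 4000, (4, 1, 400))         # (B, 1, S) per-position timesteps
args = timesteps.float() * freqs.view(1, half, 1)    # broadcasts to (B, half, S)
args = args.permute(0, 2, 1)                         # (B, S, half)
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
print(embedding.shape)                               # torch.Size([4, 400, 512])
```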

View File

@@ -365,8 +365,11 @@ class RMSScaleShiftNorm(nn.Module):
         norm = x / norm.clamp(min=self.eps) * self.g
         ss_emb = self.scale_shift_process(norm_scale_shift_inp)
-        scale, shift = torch.chunk(ss_emb, 2, dim=1)
-        h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        scale, shift = torch.chunk(ss_emb, 2, dim=-1)
+        if len(scale.shape) == 2:
+            scale = scale.unsqueeze(1)
+            shift = shift.unsqueeze(1)
+        h = norm * (1 + scale) + shift
         return h
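The practical effect of this change: chunking on the last dim lets the scale/shift embedding arrive either as `(B, 2C)` (one modulation per example, the old behaviour, restored by the `unsqueeze`) or as `(B, S, 2C)` (one modulation per sequence position, which the per-position timestep embeddings above now produce). A sketch with illustrative shapes:

```python
import torch

norm = torch.randn(4, 400, 512)  # (B, S, C) normalized activations

for ss_emb in (torch.randn(4, 1024), torch.randn(4, 400, 1024)):
    scale, shift = torch.chunk(ss_emb, 2, dim=-1)
    if len(scale.shape) == 2:                  # (B, 2C): broadcast over all positions
        scale, shift = scale.unsqueeze(1), shift.unsqueeze(1)
    h = norm * (1 + scale) + shift
    print(tuple(ss_emb.shape), '->', tuple(h.shape))  # both cases yield (4, 400, 512)
```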

View File

@@ -339,7 +339,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_tfd12_finetune_ar_outputs.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_cheater_gen.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

View File

@@ -1,4 +1,5 @@
 import functools
+import random

 import torch
 from torch.cuda.amp import autocast
@@ -44,6 +45,8 @@ class GaussianDiffusionInjector(Injector):
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
         self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
+        self.causal_mode = opt_get(opt, ['causal_mode'], False)
+        self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1,8])

         k = 0
         if 'channel_balancer_proportion' in opt.keys():
@@ -86,7 +89,16 @@ class GaussianDiffusionInjector(Injector):
             self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
         model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
         t, weights = sampler.sample(hq.shape[0], hq.device)
-        diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
+        if self.causal_mode:
+            cs, ce = self.causal_slope_range
+            slope = random.random() * (ce-cs) + cs
+            diffusion_outputs = self.diffusion.causal_training_losses(gen, hq, t, model_kwargs=model_inputs,
+                                                                      channel_balancing_fn=self.channel_balancing_fn,
+                                                                      causal_slope=slope)
+        else:
+            diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs,
+                                                               channel_balancing_fn=self.channel_balancing_fn)
+
         if isinstance(sampler, LossAwareSampler):
             sampler.update_with_local_losses(t, diffusion_outputs['loss'])
         if len(self.extra_model_output_keys) > 0:
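For reference, a hypothetical options-file fragment wiring up the new mode. Only `causal_mode` and `causal_slope_range` come from the `opt_get` calls above; every other key and the surrounding structure are guesses at this trainer's usual injector layout, not verified against the repo:

```yaml
injectors:
  diffusion:
    type: gaussian_diffusion        # assumed injector type name
    causal_mode: true               # route through causal_training_losses()
    causal_slope_range: [1, 8]      # a slope is drawn uniformly from this range each step
```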