Support causal diffusion!
This commit is contained in:
parent
78bba690de
commit
7b4dcbf136
|
@ -81,11 +81,14 @@ class ConditioningEncoder(nn.Module):
|
|||
attn_blocks=6,
|
||||
num_attn_heads=8,
|
||||
dropout=.1,
|
||||
do_checkpointing=False):
|
||||
do_checkpointing=False,
|
||||
time_proj=True):
|
||||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
|
||||
self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
|
||||
self.time_proj = time_proj
|
||||
if time_proj:
|
||||
self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
|
||||
self.attn = Encoder(
|
||||
dim=embedding_dim,
|
||||
depth=attn_blocks,
|
||||
|
@ -103,8 +106,9 @@ class ConditioningEncoder(nn.Module):
|
|||
|
||||
def forward(self, x, time_emb):
|
||||
h = self.init(x).permute(0,2,1)
|
||||
time_enc = self.time_proj(time_emb)
|
||||
h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
|
||||
if self.time_proj:
|
||||
time_enc = self.time_proj(time_emb)
|
||||
h = torch.cat([time_enc.unsqueeze(1), h], dim=1)
|
||||
h = self.attn(h).permute(0,2,1)
|
||||
return h
|
||||
|
||||
|
@ -125,6 +129,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
input_cond_dim=1024,
|
||||
num_heads=8,
|
||||
dropout=0,
|
||||
time_proj=False,
|
||||
use_fp16=False,
|
||||
checkpoint_conditioning=True, # This will need to be false for DDP training. :(
|
||||
# Parameters for regularization.
|
||||
|
@ -141,7 +146,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
self.enable_fp16 = use_fp16
|
||||
|
||||
self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
|
||||
self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning)
|
||||
self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)
|
||||
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
|
@ -210,11 +215,12 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
cond_left = cond_aligned[:,:,:break_pt]
|
||||
cond_right = cond_aligned[:,:,break_pt:]
|
||||
|
||||
# Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
|
||||
to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
|
||||
cond_left = cond_left[:,:,:-to_remove_left]
|
||||
to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
|
||||
cond_right = cond_right[:,:,to_remove_right:]
|
||||
if self.training:
|
||||
# Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
|
||||
to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
|
||||
cond_left = cond_left[:,:,:-to_remove_left]
|
||||
to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
|
||||
cond_right = cond_right[:,:,to_remove_right:]
|
||||
|
||||
# Concatenate the _pre and _post back on.
|
||||
cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
|
||||
|
|
|
@ -7,14 +7,15 @@ Docstrings have been added, as well as DDIM sampling and a new collection of bet
|
|||
|
||||
import enum
|
||||
import math
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch as th
|
||||
from tqdm import tqdm
|
||||
|
||||
from .nn import mean_flat
|
||||
from .losses import normal_kl, discretized_gaussian_log_likelihood
|
||||
from models.diffusion.nn import mean_flat
|
||||
from models.diffusion.losses import normal_kl, discretized_gaussian_log_likelihood
|
||||
|
||||
|
||||
def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True):
|
||||
|
@ -37,43 +38,18 @@ def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=T
|
|||
# This algorithm for adding causality does so by simply adding S_sloped additional timesteps. To make this
|
||||
# actually work, we map the existing t from the timescale specified to the model to the causal timescale:
|
||||
adj_t = t * (num_timesteps + S_sloped) // num_timesteps
|
||||
adj_t = adj_t - S_sloped
|
||||
if add_jitter:
|
||||
jitter = (random.random() - .5) * S_sloped
|
||||
adj_t = (adj_t+jitter).clamp(0, num_timesteps+S_sloped)
|
||||
t_gap = (num_timesteps + S_sloped) / num_timesteps
|
||||
jitter = (2*random.random()-1) * t_gap
|
||||
adj_t = (adj_t+jitter).clamp(-S_sloped, num_timesteps)
|
||||
|
||||
# Now use the re-mapped adj_t to create a timestep vector that propagates across the sequence with the specified slope.
|
||||
t = adj_t.unsqueeze(1).repeat(1, S)
|
||||
t = (t - torch.arange(0, S) * causal_slope).clamp(0, num_timesteps).long()
|
||||
t = (t + torch.arange(0, S, device=t.device) * causal_slope).clamp(0, num_timesteps).long()
|
||||
return t
|
||||
|
||||
|
||||
def graph_causal_timestep_adjustment():
|
||||
S = 400
|
||||
slope=4
|
||||
num_timesteps=4000
|
||||
#for num_timesteps in range(100, 4000, 200):
|
||||
t_res = []
|
||||
for t in range(num_timesteps, -1, -num_timesteps//50):
|
||||
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
||||
t_res.append(T)
|
||||
plt.plot(T.numpy())
|
||||
plt.ylim(0,4000)
|
||||
plt.xlim(0,500)
|
||||
plt.savefig(f'{t}.png')
|
||||
plt.clf()
|
||||
|
||||
for i in range(len(t_res)):
|
||||
for j in range(len(t_res)):
|
||||
if i == j:
|
||||
continue
|
||||
#assert not torch.all(t_res[i] == t_res[j])
|
||||
plt.ylim(0,4000)
|
||||
plt.xlim(0,500)
|
||||
plt.ylabel('timestep')
|
||||
plt.savefig(f'{num_timesteps}.png')
|
||||
plt.clf()
|
||||
|
||||
|
||||
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
||||
"""
|
||||
Get a pre-defined beta schedule for the given name.
|
||||
|
@ -319,7 +295,7 @@ class GaussianDiffusion:
|
|||
model_kwargs = {}
|
||||
|
||||
B, C = x.shape[:2]
|
||||
assert t.shape == (B,)
|
||||
assert t.shape == (B,) or t.shape == (B,1,x.shape[-1])
|
||||
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
|
||||
if self.conditioning_free:
|
||||
model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
|
||||
|
@ -844,7 +820,10 @@ class GaussianDiffusion:
|
|||
|
||||
# At the first timestep return the decoder NLL,
|
||||
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||
output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
|
||||
if len(t.shape) == 1:
|
||||
output = th.where((t == 0).view(-1, 1, 1), decoder_nll, kl)
|
||||
else:
|
||||
output = th.where((t == 0), decoder_nll, kl)
|
||||
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def causal_training_losses(self, model, x_start, t, causal_slope=1, model_kwargs=None, noise=None, channel_balancing_fn=None):
|
||||
|
@ -852,8 +831,9 @@ class GaussianDiffusion:
|
|||
Compute training losses for a causal diffusion process.
|
||||
"""
|
||||
assert len(x_start.shape) == 3, "causal_training_losses assumes a 1d sequence with the axis being the time axis."
|
||||
t = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
|
||||
return self.training_losses(model, x_start, t, model_kwargs, noise, channel_balancing_fn)
|
||||
ct = causal_timestep_adjustment(t, x_start.shape[-1], self.num_timesteps, causal_slope, add_jitter=True)
|
||||
ct = ct.unsqueeze(1) # Necessary to make the output shape compatible with x_start.
|
||||
return self.training_losses(model, x_start, ct, model_kwargs, noise, channel_balancing_fn)
|
||||
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, channel_balancing_fn=None):
|
||||
"""
|
||||
|
@ -872,6 +852,13 @@ class GaussianDiffusion:
|
|||
model_kwargs = {}
|
||||
if noise is None:
|
||||
noise = th.randn_like(x_start)
|
||||
|
||||
if len(t.shape) == 3:
|
||||
t_mask = t != self.num_timesteps
|
||||
t[t_mask.logical_not()] = self.num_timesteps-1
|
||||
else:
|
||||
t_mask = torch.ones_like(x_start)
|
||||
|
||||
x_t = self.q_sample(x_start, t, noise=noise)
|
||||
|
||||
terms = {}
|
||||
|
@ -912,7 +899,7 @@ class GaussianDiffusion:
|
|||
x_t=x_t,
|
||||
t=t,
|
||||
clip_denoised=False,
|
||||
)["output"]
|
||||
)["output"] * t_mask
|
||||
if self.loss_type == LossType.RESCALED_MSE:
|
||||
# Divide by 1000 for equivalence with initial implementation.
|
||||
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
||||
|
@ -932,7 +919,7 @@ class GaussianDiffusion:
|
|||
else:
|
||||
raise NotImplementedError(self.model_mean_type)
|
||||
assert model_output.shape == target.shape == x_start.shape
|
||||
s_err = (target - model_output) ** 2
|
||||
s_err = t_mask * (target - model_output) ** 2
|
||||
if channel_balancing_fn is not None:
|
||||
s_err = channel_balancing_fn(s_err)
|
||||
terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
|
||||
|
@ -1039,3 +1026,47 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
|||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res.expand(broadcast_shape)
|
||||
|
||||
|
||||
def test_causal_training_losses():
|
||||
from models.diffusion.respace import SpacedDiffusion
|
||||
from models.diffusion.respace import space_timesteps
|
||||
diff = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
|
||||
conditioning_free=False, conditioning_free_k=1)
|
||||
class IdentityTwoArg(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x.repeat(1,2,1)
|
||||
|
||||
model = IdentityTwoArg()
|
||||
diff.causal_training_losses(model, torch.randn(4,256,400), torch.tensor([500,1000,3000,3500]), causal_slope=4)
|
||||
|
||||
def graph_causal_timestep_adjustment():
|
||||
import matplotlib.pyplot as plt
|
||||
S = 400
|
||||
#slope=4
|
||||
num_timesteps=4000
|
||||
for slpe in range(10, 400, 10):
|
||||
slope = slpe / 10
|
||||
t_res = []
|
||||
for t in range(num_timesteps, -1, -num_timesteps//50):
|
||||
T = causal_timestep_adjustment(torch.tensor([t]), S, num_timesteps, causal_slope=slope, add_jitter=False)[0]
|
||||
t_res.append(T)
|
||||
plt.plot(T.numpy())
|
||||
|
||||
for i in range(len(t_res)):
|
||||
for j in range(len(t_res)):
|
||||
if i == j:
|
||||
continue
|
||||
#assert not torch.all(t_res[i] == t_res[j])
|
||||
plt.ylim(0,num_timesteps)
|
||||
plt.xlim(0,4000)
|
||||
plt.ylabel('timestep')
|
||||
plt.savefig(f'{slpe}.png')
|
||||
plt.clf()
|
||||
|
||||
if __name__ == '__main__':
|
||||
#test_causal_training_losses()
|
||||
graph_causal_timestep_adjustment()
|
|
@ -122,7 +122,10 @@ def timestep_embedding(timesteps, dim, max_period=10000):
|
|||
freqs = th.exp(
|
||||
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
if len(timesteps.shape) == 1:
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
else:
|
||||
args = (timesteps.float() * freqs.view(1,half,1)).permute(0,2,1)
|
||||
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
|
||||
|
|
|
@ -365,8 +365,11 @@ class RMSScaleShiftNorm(nn.Module):
|
|||
norm = x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
ss_emb = self.scale_shift_process(norm_scale_shift_inp)
|
||||
scale, shift = torch.chunk(ss_emb, 2, dim=1)
|
||||
h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
scale, shift = torch.chunk(ss_emb, 2, dim=-1)
|
||||
if len(scale.shape) == 2:
|
||||
scale = scale.unsqueeze(1)
|
||||
shift = shift.unsqueeze(1)
|
||||
h = norm * (1 + scale) + shift
|
||||
return h
|
||||
|
||||
|
||||
|
|
|
@ -339,7 +339,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_tfd12_finetune_ar_outputs.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_cheater_gen.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import functools
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast
|
||||
|
@ -44,6 +45,8 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
|
||||
self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
|
||||
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
|
||||
self.causal_mode = opt_get(opt, ['causal_mode'], False)
|
||||
self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1,8])
|
||||
|
||||
k = 0
|
||||
if 'channel_balancer_proportion' in opt.keys():
|
||||
|
@ -86,7 +89,16 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
|
||||
model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()}
|
||||
t, weights = sampler.sample(hq.shape[0], hq.device)
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs, channel_balancing_fn=self.channel_balancing_fn)
|
||||
if self.causal_mode:
|
||||
cs, ce = self.causal_slope_range
|
||||
slope = random.random() * (ce-cs) + cs
|
||||
diffusion_outputs = self.diffusion.causal_training_losses(gen, hq, t, model_kwargs=model_inputs,
|
||||
channel_balancing_fn=self.channel_balancing_fn,
|
||||
causal_slope=slope)
|
||||
else:
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs,
|
||||
channel_balancing_fn=self.channel_balancing_fn)
|
||||
|
||||
if isinstance(sampler, LossAwareSampler):
|
||||
sampler.update_with_local_losses(t, diffusion_outputs['loss'])
|
||||
if len(self.extra_model_output_keys) > 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user