From dbebe1860221af3967af6eab142b9d3f307f1fb7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 20 Jul 2022 12:12:33 -0600 Subject: [PATCH] Fix ts=0 with new formulation --- codes/models/audio/music/transformer_diffusion13.py | 9 +++++---- codes/models/diffusion/gaussian_diffusion.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index fe47cef9..292e86f2 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -192,8 +192,9 @@ class TransformerDiffusion(nn.Module): # Now diffuse the prior randomly between the x timestep and 0. adv = torch.rand_like(ts.float()) - t_prior = (adv * ts).long() - s_prior_diffused = diffuser.q_sample(s_prior, t_prior, torch.randn_like(s_prior)) + t_prior = (adv * ts).long() - 1 + # The t_prior-1 below is an important detail: it forces s_prior to be unmodified for ts=0. It also means that t_prior is not on the same timescale as ts (instead it is shifted by 1). + s_prior_diffused = diffuser.q_sample(s_prior, t_prior-1, torch.randn_like(s_prior), allow_negatives=True) self.preprocessed = (s_prior_diffused, t_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)) return s @@ -226,7 +227,7 @@ class TransformerDiffusion(nn.Module): else: assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}' x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True) - assert torch.all(timesteps - prior_timesteps > 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}' + assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}' if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) @@ -284,7 +285,7 @@ def test_tfd(): betas=get_named_beta_schedule('linear', 4000)) clip = torch.randn(2,256,10336) cond = torch.randn(2,256,10336) - ts = torch.LongTensor([600, 600]) + ts = torch.LongTensor([0, 0]) model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512, num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1, unconditioned_percentage=.6) diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index c20b7b30..4a57ebb7 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -234,7 +234,7 @@ class GaussianDiffusion: ) return mean, variance, log_variance - def q_sample(self, x_start, t, noise=None): + def q_sample(self, x_start, t, noise=None, allow_negatives=False): """ Diffuse the data for a given number of diffusion steps. @@ -248,11 +248,17 @@ class GaussianDiffusion: if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape - return ( + if allow_negatives: + mask = (t < 0) + t[mask] = 0 + result = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) + if allow_negatives: + result[mask] = x_start[mask] + return result def q_posterior_mean_variance(self, x_start, x_t, t): """