forked from mrq/DL-Art-School

Fix ts=0 with new formulation

This commit is contained in:
parent 82bd62019f
commit dbebe18602
@@ -192,8 +192,9 @@ class TransformerDiffusion(nn.Module):
 
             # Now diffuse the prior randomly between the x timestep and 0.
             adv = torch.rand_like(ts.float())
-            t_prior = (adv * ts).long()
-            s_prior_diffused = diffuser.q_sample(s_prior, t_prior, torch.randn_like(s_prior))
+            t_prior = (adv * ts).long() - 1
+            # The t_prior-1 below is an important detail: it forces s_prior to be unmodified for ts=0. It also means that t_prior is not on the same timescale as ts (instead it is shifted by 1).
+            s_prior_diffused = diffuser.q_sample(s_prior, t_prior-1, torch.randn_like(s_prior), allow_negatives=True)
 
             self.preprocessed = (s_prior_diffused, t_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
             return s
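A quick sketch of the arithmetic the new comment describes (toy values, but the same torch calls as the diff): for a ts=0 element, adv * ts is 0, so t_prior lands at -1, and q_sample is then called at t_prior-1, a negative timestep that the allow_negatives path added further down turns into a pass-through of s_prior.

    import torch

    ts = torch.LongTensor([0, 600])
    adv = torch.rand_like(ts.float())
    t_prior = (adv * ts).long() - 1   # ts=0 row: (0).long() - 1 == -1
    print(t_prior[0])                 # tensor(-1)
    # q_sample then receives t_prior-1 (== -2 here); any negative timestep takes
    # the allow_negatives branch added below and returns s_prior unchanged.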
@@ -226,7 +227,7 @@ class TransformerDiffusion(nn.Module):
         else:
             assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}'
             x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
-        assert torch.all(timesteps - prior_timesteps > 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
+        assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
 
         if conditioning_free:
             code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
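Presumably the strict inequality had to go because ts=0 lets the prior sit at the same timestep as the input. A toy check (the prior_timesteps value here is an assumption, mirroring the ts=0 test below):

    import torch

    timesteps = torch.LongTensor([0, 0])        # the new ts=0 test case
    prior_timesteps = torch.LongTensor([0, 0])  # assumed equal at the boundary
    assert torch.all(timesteps - prior_timesteps >= 0)  # holds; the old '> 0' would raise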
@@ -284,7 +285,7 @@ def test_tfd():
                                      betas=get_named_beta_schedule('linear', 4000))
     clip = torch.randn(2,256,10336)
     cond = torch.randn(2,256,10336)
-    ts = torch.LongTensor([600, 600])
+    ts = torch.LongTensor([0, 0])
     model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
                                  num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
                                  unconditioned_percentage=.6)
@@ -234,7 +234,7 @@ class GaussianDiffusion:
         )
         return mean, variance, log_variance
 
-    def q_sample(self, x_start, t, noise=None):
+    def q_sample(self, x_start, t, noise=None, allow_negatives=False):
         """
         Diffuse the data for a given number of diffusion steps.
 
@@ -248,11 +248,17 @@ class GaussianDiffusion:
         if noise is None:
             noise = th.randn_like(x_start)
         assert noise.shape == x_start.shape
-        return (
+        if allow_negatives:
+            mask = (t < 0)
+            t[mask] = 0
+        result = (
             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
             + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
             * noise
         )
+        if allow_negatives:
+            result[mask] = x_start[mask]
+        return result
 
     def q_posterior_mean_variance(self, x_start, x_t, t):
         """
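A standalone sketch of the allow_negatives contract (toy linear schedule, not the repo's; sqrt_ac and sqrt_omac stand in for the sqrt_alphas_cumprod buffers): negative timesteps are clamped so the schedule lookup stays legal, then their rows are overwritten with x_start, so they come back untouched.

    import torch

    sqrt_ac = torch.linspace(1.0, 0.01, 4000)   # stand-in sqrt_alphas_cumprod
    sqrt_omac = (1 - sqrt_ac ** 2).sqrt()       # stand-in sqrt_one_minus_alphas_cumprod

    def q_sample(x_start, t, noise=None, allow_negatives=False):
        if noise is None:
            noise = torch.randn_like(x_start)
        if allow_negatives:
            mask = (t < 0)
            t[mask] = 0                          # keep the schedule lookup legal
        result = sqrt_ac[t].view(-1, 1) * x_start + sqrt_omac[t].view(-1, 1) * noise
        if allow_negatives:
            result[mask] = x_start[mask]         # negative t: return x_start untouched
        return result

    x = torch.randn(3, 16)
    t = torch.LongTensor([-1, 0, 500])
    out = q_sample(x, t, allow_negatives=True)
    assert torch.equal(out[0], x[0])             # only the t=-1 row is a pass-through

One side effect worth noting: as in the diff, t[mask] = 0 writes into the caller's timestep tensor when allow_negatives is set, so a caller that reuses t afterwards may want to pass a clone.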