update tfd13 for inference

James Betker 2022-07-20 21:51:25 -06:00
parent dbebe18602
commit 40427de8e3


@@ -210,7 +210,7 @@ class TransformerDiffusion(nn.Module):
        Args:
            x: Prediction prior.
            timesteps: Number of timesteps x has been diffused for.
-           prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element.
+           prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element. If nothing is specified, then [0] is assumed, i.e. a fully diffused prior.
            x_prior: A low-resolution prior that guides the model.
            resolution: Integer indicating the operating resolution level. '0' is the highest resolution.
            conditioning_input: A semi-related (un-aligned) conditioning input which is used to guide diffusion. Similar to a class input, but hooked to a learned conditioning encoder.
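For context, the sketch below (not part of the commit) shows how these arguments are intended to fit together when calling forward() at inference time; the model construction, keyword usage, and all tensor shapes are illustrative assumptions rather than values from the repository.

import torch

# model = TransformerDiffusion(...)   # hypothetical construction; real args come from the config
x = torch.randn(1, 256, 400)          # prediction prior at the operating resolution
timesteps = torch.tensor([500])       # how long x has been diffused
x_prior = torch.randn(1, 256, 100)    # low-resolution guiding prior (more than ~4x shorter than x)
cond = torch.randn(1, 256, 80)        # un-aligned conditioning input

# With this commit, prior_timesteps may simply be omitted at inference; the model
# then assumes a fully diffused prior, i.e. prior_timesteps = torch.tensor([0]).
# out = model(x, timesteps, x_prior=x_prior, resolution=0, conditioning_input=cond)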
@@ -226,6 +226,9 @@ class TransformerDiffusion(nn.Module):
            self.preprocessed = None
        else:
            assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}'
+           if prior_timesteps is None:
+               # This is taken to mean a fully diffused prior was given.
+               prior_timesteps = torch.tensor([0], device=x.device)  # Assuming batch_size=1 for inference.
            x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
            assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
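Taken on its own, the new default behaves like the following self-contained sketch (an illustration under assumed tensor shapes, not the repository's code): when prior_timesteps is omitted, a fully diffused prior ([0]) is assumed, the low-resolution prior is upsampled to the length of x, and the timestep ordering is then asserted.

import torch
import torch.nn.functional as F

def resolve_prior(x, x_prior, timesteps, prior_timesteps=None):
    # The prior must be at least ~4x lower resolution than x.
    assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}'
    if prior_timesteps is None:
        # Nothing specified: treat the prior as fully diffused (batch_size=1 inference case).
        prior_timesteps = torch.tensor([0], device=x.device)
    # Upsample the prior to match the length of x.
    x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
    # The prior must always be more resolved (lower timestep) than the input.
    assert torch.all(timesteps - prior_timesteps >= 0), \
        f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
    return x_prior, prior_timesteps

# Example: a length-400 input guided by a length-100 prior, with prior_timesteps omitted.
x = torch.randn(1, 1, 400)
x_prior = torch.randn(1, 1, 100)
timesteps = torch.tensor([500])
x_prior_up, pt = resolve_prior(x, x_prior, timesteps)
print(x_prior_up.shape, pt)  # torch.Size([1, 1, 400]) tensor([0])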