update tfd13 for inference
This commit is contained in:
parent
dbebe18602
commit
40427de8e3
|
@ -210,7 +210,7 @@ class TransformerDiffusion(nn.Module):
|
|||
Args:
|
||||
x: Prediction prior.
|
||||
timesteps: Number of timesteps x has been diffused for.
|
||||
prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element.
|
||||
prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element. If nothing is specified, then [0] is assumed, e.g. a fully diffused prior.
|
||||
x_prior: A low-resolution prior that guides the model.
|
||||
resolution: Integer indicating the operating resolution level. '0' is the highest resolution.
|
||||
conditioning_input: A semi-related (un-aligned) conditioning input which is used to guide diffusion. Similar to a class input, but hooked to a learned conditioning encoder.
|
||||
|
@ -226,6 +226,9 @@ class TransformerDiffusion(nn.Module):
|
|||
self.preprocessed = None
|
||||
else:
|
||||
assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}'
|
||||
if prior_timesteps is None:
|
||||
# This is taken to mean a fully diffused prior was given.
|
||||
prior_timesteps = torch.tensor([0], device=x.device) # Assuming batch_size=1 for inference.
|
||||
x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
|
||||
assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user