From 40427de8e358dc99ba02ab495125717d17156de1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 20 Jul 2022 21:51:25 -0600 Subject: [PATCH] update tfd13 for inference --- codes/models/audio/music/transformer_diffusion13.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 292e86f2..43c29428 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -210,7 +210,7 @@ class TransformerDiffusion(nn.Module): Args: x: Prediction prior. timesteps: Number of timesteps x has been diffused for. - prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element. + prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element. If nothing is specified, then [0] is assumed, i.e. a fully diffused prior. x_prior: A low-resolution prior that guides the model. resolution: Integer indicating the operating resolution level. '0' is the highest resolution. conditioning_input: A semi-related (un-aligned) conditioning input which is used to guide diffusion. Similar to a class input, but hooked to a learned conditioning encoder. @@ -226,6 +226,9 @@ class TransformerDiffusion(nn.Module): self.preprocessed = None else: assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}' + if prior_timesteps is None: + # This is taken to mean a fully diffused prior was given. + prior_timesteps = torch.tensor([0], device=x.device) # Assuming batch_size=1 for inference. x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True) assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'