diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py
index 292e86f2..43c29428 100644
--- a/codes/models/audio/music/transformer_diffusion13.py
+++ b/codes/models/audio/music/transformer_diffusion13.py
@@ -210,7 +210,7 @@ class TransformerDiffusion(nn.Module):
         Args:
             x: Prediction prior.
             timesteps: Number of timesteps x has been diffused for.
-            prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element.
+            prior_timesteps: Number of timesteps x_prior has been diffused for. Must be <= timesteps for each batch element. If not specified, [0] is assumed, i.e. a fully diffused prior.
             x_prior: A low-resolution prior that guides the model.
             resolution: Integer indicating the operating resolution level. '0' is the highest resolution.
             conditioning_input: A semi-related (un-aligned) conditioning input which is used to guide diffusion. Similar to a class input, but hooked to a learned conditioning encoder.
@@ -226,6 +226,9 @@ class TransformerDiffusion(nn.Module):
             self.preprocessed = None
         else:
             assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}'
+            if prior_timesteps is None:
+                # This is taken to mean a fully diffused prior was given.
+                prior_timesteps = torch.tensor([0], device=x.device)  # Assuming batch_size=1 for inference.
             x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
             assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
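
For reviewers, a minimal standalone sketch of the defaulting behavior this patch introduces. The tensor shapes, channel count, and timestep values below are illustrative assumptions, not taken from the repo; the defaulting and interpolation lines mirror the patch itself:

```python
import torch
import torch.nn.functional as F

# Illustrative stand-ins for the model's inputs (shapes are assumptions).
x = torch.randn(1, 256, 400)        # input at the current resolution
x_prior = torch.randn(1, 256, 100)  # low-res prior; 400 > 100 * 3.9 satisfies the assert
timesteps = torch.tensor([500])     # how far x has been diffused
prior_timesteps = None              # caller omitted it

if prior_timesteps is None:
    # Mirrors the patch: None is read as a fully diffused prior.
    prior_timesteps = torch.tensor([0], device=x.device)  # assumes batch_size=1

# Upsample the prior to match x, as the surrounding code does.
x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)

# The downstream assert now passes: 500 - 0 >= 0.
assert torch.all(timesteps - prior_timesteps >= 0)
print(x_prior.shape)  # torch.Size([1, 256, 400])
```

One caveat worth noting from the inline comment: the `[0]` default is a single-element tensor that broadcasts against `timesteps`, so it is only meaningful for batch_size=1 inference; batched callers should keep passing an explicit per-element `prior_timesteps`.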