diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 346a461a..43d2755a 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -132,6 +132,7 @@ class TransformerDiffusion(nn.Module): self.enable_fp16 = use_fp16 self.resolution_steps = resolution_steps self.max_window = max_window + self.preprocessed = None self.time_embed = nn.Sequential( linear(time_embed_dim, time_embed_dim), @@ -189,15 +190,20 @@ class TransformerDiffusion(nn.Module): s_prior = x_prior[:,:,start:start+self.max_window] s_prior = F.interpolate(s_prior, scale_factor=.25, mode='linear', align_corners=True) s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True) - return s, s_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device) + self.preprocessed = (s_prior, resolution) + return s def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False): unused_params = [] + conditioning_input = x_prior if conditioning_input is None else conditioning_input + h = x if resolution is None: - h, h_sub, resolution = self.input_to_random_resolution_and_window(x, x_prior) - else: + assert self.preprocessed is not None, 'Preprocessing function not called.' h = x + h_sub, resolution = self.preprocessed + self.preprocessed = None + else: h_sub = F.interpolate(x_prior, scale_factor=4, mode='linear', align_corners=True) assert h.shape == h_sub.shape, f'{h.shape} {h_sub.shape}' diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py index 8217043e..931f8fad 100644 --- a/codes/trainer/injectors/gaussian_diffusion_injector.py +++ b/codes/trainer/injectors/gaussian_diffusion_injector.py @@ -47,6 +47,7 @@ class GaussianDiffusionInjector(Injector): self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env) self.causal_mode = opt_get(opt, ['causal_mode'], False) self.causal_slope_range = opt_get(opt, ['causal_slope_range'], [1,8]) + self.preprocess_fn = opt_get(opt, ['preprocess_fn'], None) k = 0 if 'channel_balancer_proportion' in opt.keys(): @@ -88,6 +89,9 @@ class GaussianDiffusionInjector(Injector): sampler = self.schedule_sampler self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically. model_inputs = {k: state[v] if isinstance(v, str) else v for k, v in self.model_input_keys.items()} + if self.preprocess_fn is not None: + hq = getattr(gen, self.preprocess_fn)(hq, **model_inputs) + t, weights = sampler.sample(hq.shape[0], hq.device) if self.causal_mode: cs, ce = self.causal_slope_range