diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py
index 3d11a873..09977a06 100644
--- a/codes/models/diffusion/diffusion_dvae.py
+++ b/codes/models/diffusion/diffusion_dvae.py
@@ -122,6 +122,7 @@ class DiffusionDVAE(nn.Module):
             linear(time_embed_dim, time_embed_dim),
         )
 
+        self.conditioning_enabled = conditioning_inputs_provided
         if conditioning_inputs_provided:
             self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
             self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
@@ -249,6 +250,9 @@ class DiffusionDVAE(nn.Module):
         )
 
     def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals):
+        if self.conditioning_enabled:
+            assert conditioning_inputs is not None
+
         spec_hs = self.decoder(embeddings)[::-1]
         # Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.)
         spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
@@ -257,7 +261,7 @@ class DiffusionDVAE(nn.Module):
         # Timestep embeddings and conditioning signals are combined using a small transformer.
         hs = []
         emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        if conditioning_inputs is not None:
+        if self.conditioning_enabled:
             mask = get_mask_from_lengths(num_conditioning_signals+1, conditioning_inputs.shape[1]+1) # +1 to account for the timestep embeddings we'll add.
             emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
             emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 1c92c0ce..edabeab0 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -117,7 +117,9 @@ class ExtensibleTrainer(BaseModel):
                     dnet = DistributedDataParallel(anet, delay_allreduce=True)
                 else:
                     from torch.nn.parallel.distributed import DistributedDataParallel
-                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
+                    # Do NOT be tempted to put find_unused_parameters=True here. It will not work in the current incarnation of this trainer.
+                    # Use all of your parameters in training, or delete them!
+                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
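
Note on the DiffusionDVAE change: the new conditioning_enabled flag records at construction time whether the conditioning submodules were built, so _decode_continouous can assert that the caller actually supplied conditioning_inputs. The old `if conditioning_inputs is not None:` check silently skipped the conditioning branch on a mismatched call; the assert turns that misconfiguration into an immediate failure. A minimal sketch of the pattern, with illustrative names (ConditionedModel and the nn.Linear stand-in for AudioMiniEncoder are not part of this repo):

import torch
import torch.nn as nn

class ConditionedModel(nn.Module):
    # Toy stand-in for DiffusionDVAE.
    def __init__(self, conditioning_inputs_provided: bool):
        super().__init__()
        # Record the construction-time decision so forward() can enforce it later.
        self.conditioning_enabled = conditioning_inputs_provided
        if conditioning_inputs_provided:
            self.contextual_embedder = nn.Linear(80, 512)

    def forward(self, x, conditioning_inputs=None):
        if self.conditioning_enabled:
            # Fail loudly instead of silently training without conditioning.
            assert conditioning_inputs is not None, \
                'model was built with conditioning enabled; pass conditioning_inputs'
            x = x + self.contextual_embedder(conditioning_inputs)
        return x

# Usage: ConditionedModel(True)(torch.zeros(2, 512), torch.zeros(2, 80)) works;
# omitting the second argument now raises an AssertionError.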
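
Note on the ExtensibleTrainer change: with find_unused_parameters left at its default of False, torch's DistributedDataParallel raises a runtime error at the next backward pass if any trainable parameter did not receive a gradient, which is exactly what the new comment tells you to fix at the source. A sketch of one way to audit a model for offenders before wrapping it in DDP (report_unused_parameters is a hypothetical helper, not part of this repo):

import torch.nn as nn

def report_unused_parameters(model: nn.Module):
    # Run once after a single loss.backward() on a freshly constructed model;
    # gradients accumulate across steps, so later iterations can mask offenders.
    unused = [name for name, p in model.named_parameters()
              if p.requires_grad and p.grad is None]
    for name in unused:
        print(f'unused parameter: {name}')
    return unused

# Usage sketch, on the raw module before DDP wrapping:
#   loss = model(batch).mean()
#   loss.backward()
#   report_unused_parameters(model)  # every name listed must be used or deleted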