diff --git a/codes/data/audio/preprocessed_mel_dataset.py b/codes/data/audio/preprocessed_mel_dataset.py
index 3c8e813d..8667ed88 100644
--- a/codes/data/audio/preprocessed_mel_dataset.py
+++ b/codes/data/audio/preprocessed_mel_dataset.py
@@ -20,6 +20,7 @@ class PreprocessedMelDataset(torch.utils.data.Dataset):
         if os.path.exists(cache_path):
             self.paths = torch.load(cache_path)
         else:
+            print("Building cache..")
             path = Path(path)
             self.paths = [str(p) for p in path.rglob("*.npz")]
             torch.save(self.paths, cache_path)
diff --git a/codes/models/audio/music/transformer_diffusion_with_point_conditioning.py b/codes/models/audio/music/transformer_diffusion_with_point_conditioning.py
index 444163fd..d37783db 100644
--- a/codes/models/audio/music/transformer_diffusion_with_point_conditioning.py
+++ b/codes/models/audio/music/transformer_diffusion_with_point_conditioning.py
@@ -216,10 +216,6 @@ class TransformerDiffusionWithConditioningEncoder(nn.Module):
         self.internal_step = 0
         self.diff = TransformerDiffusion(**kwargs)
         self.conditioning_encoder = ConditioningEncoder(256, kwargs['model_channels'])
-        self.encoder = UpperEncoder(256, 1024, 256).eval()
-        for p in self.encoder.parameters():
-            p.DO_NOT_TRAIN = True
-            p.requires_grad = False
 
     def forward(self, x, timesteps, true_cheater, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         cond = self.conditioning_encoder(true_cheater)
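
For context, a minimal standalone sketch (outside the patch) of the scan-once, cache-the-file-list pattern that the first hunk touches. The class name MinimalMelPathCache, the constructor arguments, and the "mel" npz key are hypothetical stand-ins; only the rglob/torch.save caching logic mirrors the diff.

import os
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset


class MinimalMelPathCache(Dataset):
    """Scan a directory tree for .npz mels once, then reuse the cached file list."""

    def __init__(self, path, cache_path):
        if os.path.exists(cache_path):
            # A previous run already enumerated the files; skip the walk.
            self.paths = torch.load(cache_path)
        else:
            print("Building cache..")  # mirrors the message added in the first hunk
            self.paths = [str(p) for p in Path(path).rglob("*.npz")]
            torch.save(self.paths, cache_path)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # The "mel" key is an assumption; the real dataset's npz layout may differ.
        with np.load(self.paths[i]) as npz:
            return torch.from_numpy(npz["mel"])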