diff --git a/codes/models/audio/music/tfdpc_v3.py b/codes/models/audio/music/tfdpc_v3.py
index 59531c32..3c150516 100644
--- a/codes/models/audio/music/tfdpc_v3.py
+++ b/codes/models/audio/music/tfdpc_v3.py
@@ -215,8 +215,8 @@ class TransformerDiffusionWithConditioningEncoder(nn.Module):
         self.diff = TransformerDiffusionWithPointConditioning(**kwargs)
         self.conditioning_encoder = ConditioningEncoder(256, kwargs['model_channels'])
 
-    def forward(self, x, timesteps, true_cheater, conditioning_input=None, disable_diversity=False, conditioning_free=False):
-        cond = self.conditioning_encoder(true_cheater)
+    def forward(self, x, timesteps, conditioning_input=None, disable_diversity=False, conditioning_free=False):
+        cond = self.conditioning_encoder(conditioning_input)
         diff = self.diff(x, timesteps, conditioning_input=cond, conditioning_free=conditioning_free)
         return diff
 
diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index cf6c5141..09230b99 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -196,16 +196,16 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         unused_params = []
 
         time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
-        cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
-        cs = cond_enc[:,:,cond_start]
-        ce = cond_enc[:,:,x.shape[-1]+cond_start]
-        cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
-        cond_enc = F.interpolate(cond_enc, size=(x.shape[-1],), mode='linear').permute(0,2,1)
         if conditioning_free:
             cond = self.unconditioned_embedding
+            cond = cond.repeat(1,x.shape[-1],1)
         else:
-            cond = cond_enc
+            cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
+            cs = cond_enc[:,:,cond_start]
+            ce = cond_enc[:,:,x.shape[-1]+cond_start]
+            cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
+            cond = F.interpolate(cond_enc, size=(x.shape[-1],), mode='linear').permute(0,2,1)
 
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
             unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index e29d9409..dfc2681e 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -301,14 +301,14 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
 if __name__ == '__main__':
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen_r8.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v4\\models\\28000_generator_ema.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5\\models\\18000_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
                 'diffusion_steps': 32,
-                'conditioning_free': False, 'conditioning_free_k': 1, 'clip_audio': False, 'use_ddim': True,
+                'conditioning_free': True, 'conditioning_free_k': 1, 'clip_audio': False, 'use_ddim': True,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
                 #'partial_low': 128, 'partial_high': 192
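
For reference, a minimal self-contained sketch of the two conditioning paths in the tfdpc_v5.py hunk above: the conditioning-free branch now skips the encoder entirely and tiles the unconditioned embedding, while the conditioned branch interpolates between the encoder outputs at the segment start and end. The sizes (B, C, T_cond, T_x, cond_start) and the random/zero tensors are placeholders standing in for the real conditioning_encoder output and the learned unconditioned_embedding parameter.

import torch
import torch.nn.functional as F

B, C, T_cond, T_x, cond_start = 2, 512, 600, 300, 10
conditioning_free = False

# Stand-in for the learned nn.Parameter; broadcasts against the batch downstream.
unconditioned_embedding = torch.zeros(1, 1, C)

if conditioning_free:
    # Conditioning-free path: no encoder call, just tile the unconditioned
    # embedding across the output sequence length.
    cond = unconditioned_embedding.repeat(1, T_x, 1)                        # (1, T_x, C)
else:
    cond_enc = torch.randn(B, C, T_cond)                                    # stand-in for conditioning_encoder(...)
    cs = cond_enc[:, :, cond_start]                                         # conditioning vector at the segment start
    ce = cond_enc[:, :, T_x + cond_start]                                   # conditioning vector at the segment end
    pair = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)          # (B, C, 2)
    cond = F.interpolate(pair, size=(T_x,), mode='linear').permute(0, 2, 1) # (B, T_x, C)

print(cond.shape)  # torch.Size([2, 300, 512])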