diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index afe6a221..aaacbfd6 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -167,7 +167,10 @@ class MusicDiffusionFid(evaluator.Evaluator): codegen = self.local_modules['codegen'].to(mel.device) codes = codegen.get_codes(mel) mel_norm = normalize_mel(mel) - gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, model_kwargs={'aligned_conditioning': codes, 'conditioning_input': mel[:,:,:117]}) + precomputed = self.model.timestep_independent(aligned_conditioning=codes, conditioning_input=mel[:,:,:112], + expected_seq_len=mel_norm.shape[-1], return_code_pred=False) + gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, noise=torch.zeros_like(mel_norm), + model_kwargs={'precomputed_aligned_embeddings': precomputed}) gen_mel_denorm = denormalize_mel(gen_mel) output_shape = (1,16,audio.shape[-1]//16) @@ -243,7 +246,7 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_flat.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_diffusion_flat\\models\\26000_generator.pth').cuda() + load_path='X:\\dlas\\experiments\\train_music_diffusion_flat\\models\\33000_generator_ema.pth').cuda() opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100, 'conditioning_free': False, 'conditioning_free_k': 1, 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes'} diff --git a/codes/utils/util.py b/codes/utils/util.py index 5d309be5..a5b83ead 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -561,6 +561,7 @@ def find_audio_files(base_path, globs=['*.wav', '*.mp3', '*.ogg', '*.flac']): def load_audio(audiopath, sampling_rate, raw_data=None): + audiopath = str(audiopath) if raw_data is not None: # Assume the data is wav format. 
SciPy's reader can read raw WAV data from a BytesIO wrapper.
        audio, lsr = load_wav_to_torch(raw_data)