diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index 6dfbe823..e9c56b45 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -33,6 +33,8 @@ class AudioDiffusionFid(evaluator.Evaluator): super().__init__(model, opt_eval, env, uses_all_ddp=True) self.real_path = opt_eval['eval_tsv'] self.data = load_tsv_aligned_codes(self.real_path) + if 'clip_dataset' in opt_eval.keys(): + self.data = self.data[:opt_eval['clip_dataset']] if distributed.is_initialized() and distributed.get_world_size() > 1: self.skip = distributed.get_world_size() # One batch element per GPU. else: