fail gracefully from mfd

This commit is contained in:
James Betker 2022-05-27 20:24:16 -06:00
parent f691f5faa1
commit 76aeba7843

View File

@ -117,7 +117,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
mu2 = np.mean(proj2, axis=0) mu2 = np.mean(proj2, axis=0)
sigma1 = np.cov(proj1, rowvar=False) sigma1 = np.cov(proj1, rowvar=False)
sigma2 = np.cov(proj2, rowvar=False) sigma2 = np.cov(proj2, rowvar=False)
try:
return torch.tensor(calculate_frechet_distance(mu1, sigma1, mu2, sigma2)) return torch.tensor(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))
except:
return 0
def perform_eval(self): def perform_eval(self):
save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"])) save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
@ -166,13 +169,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__': if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator', diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd.yml', 'generator',
also_load_savepoint=False, also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_waveform_gen_reformed_mel\\models\\57500_generator_ema.pth' load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd\\models\\3000_generator_ema.pth'
).cuda() ).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500, opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
'conditioning_free': True, 'conditioning_free_k': 1, 'conditioning_free': True, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'} 'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 26, 'device': 'cuda', 'opt': {}} env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 26, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env) eval = MusicDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval()) print(eval.perform_eval())