fail gracefully from mfd

This commit is contained in:
James Betker 2022-05-27 20:24:16 -06:00
parent f691f5faa1
commit 76aeba7843

View File

@ -117,7 +117,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
mu2 = np.mean(proj2, axis=0)
sigma1 = np.cov(proj1, rowvar=False)
sigma2 = np.cov(proj2, rowvar=False)
try:
return torch.tensor(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))
except:
return 0
def perform_eval(self):
save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
@ -166,13 +169,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_waveform_gen.yml', 'generator',
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_waveform_gen_reformed_mel\\models\\57500_generator_ema.pth'
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd\\models\\3000_generator_ema.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
'conditioning_free': True, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'}
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 26, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval())