import functools
import os
import os.path as osp
import warnings
from glob import glob
from random import shuffle
from time import time

import numpy as np
import torch
import torchaudio
import torchvision
from pytorch_fid.fid_score import calculate_frechet_distance
from torch import distributed
from tqdm import tqdm

import trainer.eval.evaluator as evaluator
from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.mel2vec import ContrastiveTrainingWrapper
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
from models.clip.contrastive_audio import ContrastiveAudio
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
    normalize_mel, KmeansQuantizerInjector
from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
    get_mel2wav_v3_model, get_ar_prior
from utils.util import opt_get, load_model_from_config


class MusicDiffusionFid(evaluator.Evaluator):
    """
    Evaluator that generates audio with a music diffusion model and scores it against reference clips.

    For every `.wav` file found under `opt_eval['path']`, the routine selected by `opt_eval['diffusion_type']`
    produces a generated waveform and a reference waveform. Both are projected with a pretrained `ContrastiveAudio`
    model and the Frechet distance between the two sets of projections is reported. The generated/reference audio
    and MEL images are also written to disk so they can be inspected.
    """

    def __init__(self, model, opt_eval, env):
        """
        Build the diffusers and auxiliary models required by the selected diffusion mode.

        :param model: The diffusion model under evaluation.
        :param opt_eval: Evaluator options. Keys read here: 'path' (required), 'clip_audio' (default True),
                         'use_ddim' (default False), 'diffusion_steps' (default 50), 'diffusion_schedule'
                         (fallback, default 'linear'), 'conditioning_free' (default False), 'conditioning_free_k'
                         (default 1), 'diffusion_type' (default 'spec_decode'), 'squeeze_ratio' (required for
                         'spec_decode'), 'partial_low'/'partial_high' (required for 'partial_from_codes_quant').
        :param env: Environment dict. `env['opt']` is searched for the generator's beta schedule name.
        :raises ValueError: If `diffusion_type` is not one of the supported modes.
        """
        super().__init__(model, opt_eval, env, uses_all_ddp=True)
        self.real_path = opt_eval['path']
        self.data = self.load_data(self.real_path)
        self.clip = opt_get(opt_eval, ['clip_audio'], True)  # Recommend setting true for more efficient eval passes.
        self.ddim = opt_get(opt_eval, ['use_ddim'], False)
        if distributed.is_initialized() and distributed.get_world_size() > 1:
            self.skip = distributed.get_world_size()  # One batch element per GPU.
        else:
            self.skip = 1

        diffusion_steps = opt_get(opt_eval, ['diffusion_steps'], 50)
        diffusion_schedule = opt_get(env['opt'], ['steps', 'generator', 'injectors', 'diffusion', 'beta_schedule',
                                                  'schedule_name'], None)
        if diffusion_schedule is None:
            print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
            diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'linear')
        conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
        conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)

        # Diffuser used to sample from the model under evaluation.
        self.diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [diffusion_steps]),
                                        model_mean_type='epsilon', model_var_type='learned_range', loss_type='mse',
                                        betas=get_named_beta_schedule(diffusion_schedule, 4000),
                                        conditioning_free=conditioning_free_diffusion_enabled,
                                        conditioning_free_k=conditioning_free_k)
        # Diffuser used to sample from the spectrogram decoder (MEL -> waveform).
        self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16 if self.ddim else 100]),
                                                 model_mean_type='epsilon', model_var_type='learned_range',
                                                 loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                 conditioning_free=False, conditioning_free_k=1)
        self.dev = self.env['device']
        mode = opt_get(opt_eval, ['diffusion_type'], 'spec_decode')

        self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8,
                                          mel_channels=256)
        self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth',
                                                  map_location=torch.device('cpu')))

        # Every auxiliary module is tracked here so perform_eval() can move them all back to the CPU when done.
        self.local_modules = {'projector': self.projector}
        if mode == 'spec_decode':
            self.diffusion_fn = self.perform_diffusion_spec_decode
            self.squeeze_ratio = opt_eval['squeeze_ratio']
        elif mode == 'from_codes':
            self.diffusion_fn = self.perform_diffusion_from_codes
            self.local_modules['codegen'] = get_music_codegen()
        elif mode == 'from_codes_quant':
            self.diffusion_fn = self.perform_diffusion_from_codes_quant
        elif mode == 'partial_from_codes_quant':
            self.diffusion_fn = functools.partial(self.perform_partial_diffusion_from_codes_quant,
                                                  partial_low=opt_eval['partial_low'],
                                                  partial_high=opt_eval['partial_high'])
        elif mode == 'from_codes_quant_gradual_decode':
            self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode
        elif mode == 'cheater_gen':
            self.diffusion_fn = self.perform_reconstruction_from_cheater_gen
            self.local_modules['cheater_encoder'] = get_cheater_encoder()
            self.local_modules['cheater_decoder'] = get_cheater_decoder()
            self.cheater_decoder_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]),
                                                            model_mean_type='epsilon',
                                                            model_var_type='learned_range', loss_type='mse',
                                                            betas=get_named_beta_schedule('linear', 4000),
                                                            conditioning_free=True, conditioning_free_k=1)
            # The only reason the other functions don't use v3 is because earlier models were trained with v1 and
            # I want to keep metrics consistent.
            self.spec_decoder = get_mel2wav_v3_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        elif mode == 'from_ar_prior':
            self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
            self.local_modules['cheater_encoder'] = get_cheater_encoder()
            self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth',
                                                       'in': 'in', 'out': 'out'}, {})
            self.local_modules['ar_prior'] = get_ar_prior()
            self.spec_decoder = get_mel2wav_v3_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        else:
            # Fail at construction rather than with an AttributeError on `diffusion_fn` at the first eval pass.
            raise ValueError(f"Unsupported diffusion_type for MusicDiffusionFid: {mode!r}")

        # Modes that did not pick a decoder above fall back to the original mel2wav model.
        if not hasattr(self, 'spec_decoder'):
            self.spec_decoder = get_mel2wav_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
                                                    'normalize': True, 'in': 'in', 'out': 'out'}, {})

    def load_data(self, path):
        """
        Return the `.wav` files directly under `path`.

        The list is sorted because glob() makes no ordering guarantee, while perform_eval() selects clips by index
        (offset by rank) and therefore needs a stable order.
        """
        return sorted(glob(f'{path}/*.wav'))

    def _decode_wav(self, mel, audio, cond_key, use_ddim=False):
        """
        Turn `mel` into a waveform by sampling the spectrogram decoder with `self.spectral_diffuser`.

        :param mel: Conditioning spectrogram handed to the decoder.
        :param audio: Reference audio (batch dimension already added); supplies the target length and device.
        :param cond_key: Name of the decoder keyword argument that receives `mel`.
        :param use_ddim: Sample with `ddim_sample_loop` instead of `p_sample_loop`.
        :return: The sampled waveform after `pixel_shuffle_1d(., 16)`.
        """
        squeeze = 16
        output_shape = (1, squeeze, audio.shape[-1] // squeeze)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        sampler = self.spectral_diffuser.ddim_sample_loop if use_ddim else self.spectral_diffuser.p_sample_loop
        wav = sampler(self.spec_decoder, output_shape, model_kwargs={cond_key: mel})
        return pixel_shuffle_1d(wav, squeeze)

    def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
        """
        Generate a waveform directly with the evaluated model, conditioned on the MEL of `audio`.

        :return: (generated wav, reference wav, normalized MEL of the generated wav, normalized reference MEL,
                  sample_rate)
        """
        real_resampled = audio
        audio = audio.unsqueeze(0)
        output_shape = (1, self.squeeze_ratio, audio.shape[-1] // self.squeeze_ratio)
        mel = self.spec_fn({'in': audio})['out']
        gen = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'codes': mel})
        gen = pixel_shuffle_1d(gen, self.squeeze_ratio)

        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate

    def perform_diffusion_from_codes(self, audio, sample_rate=22050):
        """
        Generate a MEL from codes produced by the 'codegen' module, then decode that MEL to a waveform.

        :return: (generated wav, reference wav, generated MEL, normalized reference MEL, sample_rate)
        """
        real_resampled = audio
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        codegen = self.local_modules['codegen'].to(mel.device)
        codes = codegen.get_codes(mel, project=True)
        mel_norm = normalize_mel(mel)
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                              model_kwargs={'codes': codes,
                                                            'conditioning_input': torch.zeros_like(mel_norm[:, :, :390])})
        gen_mel_denorm = denormalize_mel(gen_mel)
        gen_wav = self._decode_wav(gen_mel_denorm, audio, 'aligned_conditioning')
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
        """
        Regenerate the MEL of `audio` with the evaluated model and decode both it and the true MEL to waveforms.

        The reference waveform is itself produced by the spectrogram decoder (from the true MEL), not taken from
        the input audio.

        :return: (generated wav, decoded reference wav, generated MEL, normalized reference MEL, sample_rate)
        """
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        #def denoising_fn(x):
        #    q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
        #    s = q9.clamp(1, 9999999999)
        #    x = x.clamp(-s, s) / s
        #    return x
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,  #denoised_fn=denoising_fn, clip_denoised=False,
                                              model_kwargs={'truth_mel': mel_norm,
                                                            'conditioning_input': mel_norm,
                                                            'disable_diversity': True})
        gen_mel_denorm = denormalize_mel(gen_mel)
        gen_wav = self._decode_wav(gen_mel_denorm, audio, 'aligned_conditioning')
        real_wav = self._decode_wav(mel, audio, 'aligned_conditioning')
        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256):
        """
        Regenerate only MEL channels [partial_low, partial_high); the remaining channels are guided by the truth.

        :return: (generated wav, reference wav, generated MEL, normalized reference MEL, sample_rate)
        """
        real_resampled = audio
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        mask = torch.ones_like(mel_norm)
        mask[:, partial_low:partial_high] = 0  # This is the channel region that the model will predict.
        # NOTE(review): 'truth_mel' receives the un-normalized MEL here but `mel_norm` in
        # perform_diffusion_from_codes_quant - confirm which one the model expects.
        gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model,
                                                            guidance_input=mel_norm, mask=mask,
                                                            model_kwargs={'truth_mel': mel,
                                                                          'conditioning_input': torch.zeros_like(mel_norm[:, :, :390]),
                                                                          'disable_diversity': True})
        gen_mel_denorm = denormalize_mel(gen_mel)
        gen_wav = self._decode_wav(gen_mel_denorm, audio, 'aligned_conditioning')
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
        """
        Generate the MEL in several passes, fixing one more band of channels as guidance after each pass.

        :return: (generated wav, reference wav, generated MEL of the last pass, normalized reference MEL,
                  sample_rate)
        """
        real_resampled = audio
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        guidance = torch.zeros_like(mel_norm)
        mask = torch.zeros_like(mel_norm)
        GRADS = 4
        for k in range(GRADS):
            # NOTE(review): 'truth_mel' receives the un-normalized MEL here but `mel_norm` in
            # perform_diffusion_from_codes_quant - confirm which one the model expects.
            gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model,
                                                                guidance_input=guidance, mask=mask,
                                                                model_kwargs={'truth_mel': mel,
                                                                              'conditioning_input': torch.zeros_like(mel_norm[:, :, :390]),
                                                                              'disable_diversity': True})
            # Copy band k of this pass into the guidance and mark every channel up to its end as fixed.
            pk = int(k * (mel_norm.shape[1] / GRADS))
            ek = int((k + 1) * (mel_norm.shape[1] / GRADS))
            guidance[:, pk:ek] = gen_mel[:, pk:ek]
            mask[:, :ek] = 1
        gen_mel_denorm = denormalize_mel(gen_mel)
        # The spectrogram decoder is sampled with `spectral_diffuser`, as in every other mode; `self.diffuser`
        # carries the schedule/steps/conditioning-free settings of the evaluated model, not of the decoder.
        gen_wav = self._decode_wav(gen_mel_denorm, audio, 'aligned_conditioning')
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_reconstruction_from_cheater_gen(self, audio, sample_rate=22050):
        """
        Generate a "cheater" latent with the evaluated model, decode it to a MEL and then to a waveform.

        :return: (generated wav, decoded reference wav, generated MEL, normalized reference MEL, sample_rate)
        """
        assert self.ddim, "DDIM mode expected for reconstructing cheater gen. Do you like to waste resources??"
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
        # 1. Generate the cheater latent using the input as a reference.
        gen_cheater = self.diffuser.ddim_sample_loop(self.model, cheater.shape, progress=True,
                                                     model_kwargs={'conditioning_input': cheater})
        # 2. Decode the cheater into a MEL
        gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device),
                                                                 (1, 256, gen_cheater.shape[-1] * 16), progress=True,
                                                                 model_kwargs={'codes': gen_cheater.permute(0, 2, 1)})
        # 3. And then the MEL back into a waveform
        gen_mel_denorm = denormalize_mel(gen_mel)
        gen_wav = self._decode_wav(gen_mel_denorm, audio, 'codes')
        real_wav = self._decode_wav(mel, audio, 'codes')
        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_ar_prior(self, audio, sample_rate=22050):
        """
        Generate a MEL conditioned on the latent of the AR prior, then decode it to a waveform with DDIM sampling.

        :return: (generated wav, decoded reference wav, generated MEL, normalized reference MEL, sample_rate)
        """
        assert self.ddim, "DDIM mode expected for reconstructing cheater gen. Do you like to waste resources??"
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
        cheater_codes = self.kmeans_inj({'in': cheater})['out']
        ar_latent = self.local_modules['ar_prior'].to(audio.device)(cheater_codes, cheater, return_latent=True)
        gen_mel = self.diffuser.ddim_sample_loop(self.model, mel_norm.shape, model_kwargs={'codes': ar_latent},
                                                 progress=True)
        gen_mel_denorm = denormalize_mel(gen_mel)
        gen_wav = self._decode_wav(gen_mel_denorm, audio, 'codes', use_ddim=True)
        real_wav = self._decode_wav(mel, audio, 'codes', use_ddim=True)
        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def project(self, sample, sample_rate):
        """Resample `sample` to 22050Hz, compute its MEL and return the projector's embedding of it."""
        sample = torchaudio.functional.resample(sample, sample_rate, 22050)
        mel = self.spec_fn({'in': sample})['out']
        projection = self.projector.project(mel)
        return projection.squeeze(0)  # Getting rid of the batch dimension means it's just [hidden_dim]

    def compute_frechet_distance(self, proj1, proj2):
        """
        Compute the Frechet distance between two sets of projections (one projection per row).

        :return: A 0-dim float64 tensor. If the computation fails, a warning is emitted and 0 is returned so the
                 evaluation pass can still complete.
        """
        # pytorch_fid operates on numpy arrays, so the projections have to leave torch here.
        proj1 = proj1.cpu().numpy()
        proj2 = proj2.cpu().numpy()
        mu1 = np.mean(proj1, axis=0)
        mu2 = np.mean(proj2, axis=0)
        sigma1 = np.cov(proj1, rowvar=False)
        sigma2 = np.cov(proj2, rowvar=False)
        try:
            distance = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
        except Exception as e:
            warnings.warn(f"Frechet distance computation failed ({e!r}); reporting 0.")
            distance = 0.0
        # Both paths yield the same dtype so that all_reduce() sees matching tensors on every rank.
        return torch.tensor(distance, dtype=torch.float64)

    def perform_eval(self):
        """
        Run one evaluation pass over the reference clips.

        Writes generated/reference audio and MEL images under `<base_path>/../audio_eval/<step>` and returns
        `{"frechet_distance": tensor}`. With more than one distributed process, the distance is averaged across
        ranks. The model's train mode, the torch RNG state and the CPU placement of the auxiliary modules are
        restored even when generation raises.
        """
        save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
        os.makedirs(save_path, exist_ok=True)
        rank = self.env['rank']

        self.projector = self.projector.to(self.dev)
        self.projector.eval()

        # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
        rng_state = torch.get_rng_state()
        torch.manual_seed(5)
        self.model.eval()
        try:
            with torch.no_grad():
                gen_projections = []
                real_projections = []
                for i in tqdm(range(0, len(self.data), self.skip)):
                    # Each rank takes a different clip out of every group of `skip` clips.
                    path = self.data[(i + rank) % len(self.data)]
                    audio = load_audio(path, 22050).to(self.dev)
                    if self.clip:
                        audio = audio[:, :100000]
                    sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)

                    # Store on CPU to avoid wasting GPU memory.
                    gen_projections.append(self.project(sample, sample_rate).cpu())
                    real_projections.append(self.project(ref, sample_rate).cpu())

                    torchaudio.save(osp.join(save_path, f"{rank}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
                    torchaudio.save(osp.join(save_path, f"{rank}_{i}_real.wav"), ref.cpu(), sample_rate)
                    torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2,
                                                 osp.join(save_path, f"{rank}_{i}_gen_mel.png"))
                    torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2,
                                                 osp.join(save_path, f"{rank}_{i}_real_mel.png"))
                gen_projections = torch.stack(gen_projections, dim=0)
                real_projections = torch.stack(real_projections, dim=0)
                # as_tensor() moves the result to the device without copy-constructing a tensor from a tensor.
                frechet_distance = torch.as_tensor(self.compute_frechet_distance(gen_projections, real_projections),
                                                   device=self.env['device'])

                if distributed.is_initialized() and distributed.get_world_size() > 1:
                    distributed.all_reduce(frechet_distance)
                    frechet_distance = frechet_distance / distributed.get_world_size()
        finally:
            self.model.train()
            torch.set_rng_state(rng_state)
            # Put modules used for evaluation back into CPU memory.
            for k, mod in self.local_modules.items():
                self.local_modules[k] = mod.cpu()
            self.spec_decoder = self.spec_decoder.cpu()

        return {"frechet_distance": frechet_distance}


if __name__ == '__main__':
    # Manual smoke test: load a generator checkpoint from local paths and run a single evaluation pass on one GPU.
    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_tfd12_finetune_ar_outputs.yml', 'generator',
                                       also_load_savepoint=False,
                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd12_finetune_from_cheater_ar\\models\\7500_generator.pth'
                                       ).cuda()
    opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
                'diffusion_steps': 32,
                'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True,  # 'clip_audio': False,
                'diffusion_schedule': 'linear', 'diffusion_type': 'from_ar_prior',
                #'partial_low': 128, 'partial_high': 192
                }
    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 230, 'device': 'cuda', 'opt': {}}
    # Named `fid_evaluator` so that neither the builtin `eval` nor the imported `evaluator` module is shadowed.
    fid_evaluator = MusicDiffusionFid(diffusion, opt_eval, env)
    print(fid_evaluator.perform_eval())