# DL-Art-School/codes/trainer/eval/music_diffusion_fid.py
# Recovered from a git blame web view (460 lines, 26 KiB, Python).
# 2022-06-10 03:41:20 +00:00
import functools
# 2022-04-20 06:28:03 +00:00
import os
import os.path as osp
from glob import glob
# 2022-05-16 03:50:54 +00:00
from random import shuffle
# 2022-05-23 16:37:15 +00:00
from time import time
# 2022-04-20 06:28:03 +00:00
# 2022-05-09 00:49:39 +00:00
import numpy as np
# 2022-04-20 06:28:03 +00:00
import torch
import torchaudio
import torchvision
# 2022-04-20 06:28:03 +00:00
from pytorch_fid.fid_score import calculate_frechet_distance
from torch import distributed
from tqdm import tqdm
import trainer.eval.evaluator as evaluator
from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.mel2vec import ContrastiveTrainingWrapper
# 2022-05-06 20:33:44 +00:00
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
# 2022-05-07 03:56:49 +00:00
from models.clip.contrastive_audio import ContrastiveAudio
# 2022-04-20 06:28:03 +00:00
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion
# 2022-05-06 20:33:44 +00:00
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
normalize_mel, KmeansQuantizerInjector
from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
get_mel2wav_v3_model, get_ar_prior
# 2022-05-09 00:49:39 +00:00
from utils.util import opt_get, load_model_from_config
# 2022-04-20 06:28:03 +00:00
class MusicDiffusionFid(evaluator.Evaluator):
    """
    Evaluator that generates music with a diffusion model and scores it with a Frechet distance.

    For every ``*.wav`` file under ``opt_eval['path']``, the pipeline selected by ``opt_eval['diffusion_type']``
    (stored as ``self.diffusion_fn``) produces a generated waveform and a reference waveform. Both are embedded with
    a pretrained ``ContrastiveAudio`` projector, and the Frechet distance between the two embedding sets is reported
    as ``frechet_distance``. Audio files and mel images are saved under ``<base_path>/../audio_eval/<step>/``.

    Every ``diffusion_fn`` receives a rank-2 ``(channels, samples)`` waveform loaded at 22050 Hz and returns
    ``(gen_wav, ref_wav, gen_mel, ref_mel, sample_rate)``.
    """

    def __init__(self, model, opt_eval, env):
        """
        Args:
            model: The diffusion network being evaluated.
            opt_eval: Evaluator options. Keys read: 'path' (required), 'clip_audio', 'use_ddim', 'causal',
                'causal_slope', 'diffusion_steps', 'diffusion_schedule', 'conditioning_free', 'conditioning_free_k',
                'diffusion_type', 'projector_path', plus 'squeeze_ratio' for the 'spec_decode' mode and
                'partial_low'/'partial_high' for the 'partial_from_codes_quant' mode.
            env: Trainer environment. 'opt' is read here directly; 'device', 'base_path', 'step' and 'rank' are read
                through ``self.env`` (presumably stored by the base Evaluator -- TODO confirm).

        Raises:
            ValueError: if 'diffusion_type' does not name a known pipeline.
        """
        super().__init__(model, opt_eval, env, uses_all_ddp=True)
        self.real_path = opt_eval['path']
        self.data = self.load_data(self.real_path)
        self.clip = opt_get(opt_eval, ['clip_audio'], True)  # Recommend setting true for more efficient eval passes.
        self.ddim = opt_get(opt_eval, ['use_ddim'], False)
        self.causal = opt_get(opt_eval, ['causal'], False)
        self.causal_slope = opt_get(opt_eval, ['causal_slope'], 1)
        if distributed.is_initialized() and distributed.get_world_size() > 1:
            self.skip = distributed.get_world_size()  # One batch element per GPU.
        else:
            self.skip = 1

        # Sampler for the model under evaluation. The beta schedule is taken from the master options when present.
        diffusion_steps = opt_get(opt_eval, ['diffusion_steps'], 50)
        diffusion_schedule = opt_get(env['opt'], ['steps', 'generator', 'injectors', 'diffusion', 'beta_schedule',
                                                  'schedule_name'], None)
        if diffusion_schedule is None:
            print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
            diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'linear')
        conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
        conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
        self.diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [diffusion_steps]), model_mean_type='epsilon',
                                        model_var_type='learned_range', loss_type='mse',
                                        betas=get_named_beta_schedule(diffusion_schedule, 4000),
                                        conditioning_free=conditioning_free_diffusion_enabled,
                                        conditioning_free_k=conditioning_free_k)
        # Sampler for self.spec_decoder (mel spectrogram -> waveform). The 'cheater_gen' mode replaces it below.
        self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16 if self.ddim else 100]),
                                                 model_mean_type='epsilon', model_var_type='learned_range',
                                                 loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                 conditioning_free=False, conditioning_free_k=1)

        self.dev = self.env['device']
        mode = opt_get(opt_eval, ['diffusion_type'], 'spec_decode')

        # Embedding network whose outputs the Frechet distance is computed over.
        self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256)
        projector_path = opt_get(opt_eval, ['projector_path'], '../experiments/music_eval_projector.pth')
        self.projector.load_state_dict(torch.load(projector_path, map_location=torch.device('cpu')))
        # Auxiliary modules. Pipelines move them to the eval device on demand; perform_eval() moves them back to CPU.
        self.local_modules = {'projector': self.projector}

        if mode == 'spec_decode':
            self.diffusion_fn = self.perform_diffusion_spec_decode
            self.squeeze_ratio = opt_eval['squeeze_ratio']
        elif mode == 'from_codes':
            self.diffusion_fn = self.perform_diffusion_from_codes
            self.local_modules['codegen'] = get_music_codegen()
        elif mode == 'from_codes_quant':
            self.diffusion_fn = self.perform_diffusion_from_codes_quant
        elif mode == 'partial_from_codes_quant':
            self.diffusion_fn = functools.partial(self.perform_partial_diffusion_from_codes_quant,
                                                  partial_low=opt_eval['partial_low'],
                                                  partial_high=opt_eval['partial_high'])
        elif mode == 'from_codes_quant_gradual_decode':
            self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode
        elif mode == 'cheater_gen':
            self.diffusion_fn = self.perform_reconstruction_from_cheater_gen
            self.local_modules['cheater_encoder'] = get_cheater_encoder()
            self.local_modules['cheater_decoder'] = get_cheater_decoder()
            self.cheater_decoder_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]),
                                                            model_mean_type='epsilon', model_var_type='learned_range',
                                                            loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                            conditioning_free=True, conditioning_free_k=1)
            self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16]),
                                                     model_mean_type='epsilon', model_var_type='learned_range',
                                                     loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                     conditioning_free=False, conditioning_free_k=1)
            # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I
            # want to keep metrics consistent.
            self.spec_decoder = get_mel2wav_v3_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        elif mode == 'from_ar_prior':
            self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
            self.local_modules['cheater_encoder'] = get_cheater_encoder()
            self.local_modules['cheater_decoder'] = get_cheater_decoder()
            self.cheater_decoder_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]),
                                                            model_mean_type='epsilon', model_var_type='learned_range',
                                                            loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                            conditioning_free=True, conditioning_free_k=1)
            self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth',
                                                       'in': 'in', 'out': 'out'}, {})
            self.local_modules['ar_prior'] = get_ar_prior()
            self.spec_decoder = get_mel2wav_v3_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        else:
            # Fail fast instead of leaving diffusion_fn unset until the first perform_eval() call.
            raise ValueError(f"Unknown diffusion_type: {mode}")
        if not hasattr(self, 'spec_decoder'):
            # Modes that did not pick a decoder above use the original (non-v3) mel2wav model.
            self.spec_decoder = get_mel2wav_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
                                                    'normalize': True, 'in': 'in', 'out': 'out'}, {})

    def load_data(self, path):
        """Return the ``*.wav`` files directly inside ``path``.

        glob() yields files in arbitrary order; the list is sorted because perform_eval() assigns files to ranks by
        index, so every rank (and every run) must see the same ordering.
        """
        return sorted(glob(f'{path}/*.wav'))

    def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
        """'spec_decode' mode: have the model regenerate the waveform directly from the input's mel spectrogram."""
        real_resampled = audio
        audio = audio.unsqueeze(0)
        # The model samples a (1, squeeze_ratio, T // squeeze_ratio) signal; pixel_shuffle_1d presumably interleaves
        # it back into a single (1, 1, T') waveform channel.
        output_shape = (1, self.squeeze_ratio, audio.shape[-1] // self.squeeze_ratio)
        mel = self.spec_fn({'in': audio})['out']
        gen = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'codes': mel})
        gen = pixel_shuffle_1d(gen, self.squeeze_ratio)
        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate

    def perform_diffusion_from_codes(self, audio, sample_rate=22050):
        """'from_codes' mode: regenerate the mel from codegen codes, then decode it to audio with spec_decoder."""
        real_resampled = audio
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        codegen = self.local_modules['codegen'].to(mel.device)
        codes = codegen.get_codes(mel, project=True)
        mel_norm = normalize_mel(mel)
        # An all-zero, 390-frame conditioning input is supplied, i.e. no reference clip conditions the model.
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                              model_kwargs={'codes': codes,
                                                            'conditioning_input': torch.zeros_like(mel_norm[:, :, :390])})
        gen_mel_denorm = denormalize_mel(gen_mel)
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                                       model_kwargs={'aligned_conditioning': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
        """'from_codes_quant' mode: regenerate the mel given the true normalized mel as 'truth_mel'.

        The reference returned is not the input audio but the true mel decoded by the same spectral decoder, so both
        sides of the comparison pass through spec_decoder.
        """
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                              model_kwargs={'truth_mel': mel_norm,
                                                            'conditioning_input': mel_norm,
                                                            'disable_diversity': True})
        gen_mel_denorm = denormalize_mel(gen_mel)
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                                       model_kwargs={'aligned_conditioning': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)
        real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                                        model_kwargs={'aligned_conditioning': mel})
        real_wav = pixel_shuffle_1d(real_wav, 16)
        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256):
        """'partial_from_codes_quant' mode: regenerate only mel channels ``[partial_low, partial_high)``.

        Guided sampling receives the true normalized mel; the mask is 0 on the channel band the model predicts and 1
        elsewhere (presumably the 1-region is pinned to the guidance -- semantics live in p_sample_loop_with_guidance).
        """
        real_resampled = audio
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        mask = torch.ones_like(mel_norm)
        mask[:, partial_low:partial_high] = 0  # This is the channel region that the model will predict.
        gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model,
                                                            guidance_input=mel_norm, mask=mask,
                                                            model_kwargs={'truth_mel': mel,
                                                                          'conditioning_input': torch.zeros_like(mel_norm[:, :, :390]),
                                                                          'disable_diversity': True})
        gen_mel_denorm = denormalize_mel(gen_mel)
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                                       model_kwargs={'aligned_conditioning': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
        """'from_codes_quant_gradual_decode' mode: regenerate the mel in GRADS guided passes.

        After pass k, channel band k (dim 1, split into GRADS equal bands) of that pass's output is copied into the
        guidance and unmasked, so later passes are guided by all bands generated so far.
        """
        real_resampled = audio
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        guidance = torch.zeros_like(mel_norm)
        mask = torch.zeros_like(mel_norm)
        GRADS = 4
        for k in range(GRADS):
            gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model,
                                                                guidance_input=guidance, mask=mask,
                                                                model_kwargs={'truth_mel': mel,
                                                                              'conditioning_input': torch.zeros_like(mel_norm[:, :, :390]),
                                                                              'disable_diversity': True})
            pk = int(k * (mel_norm.shape[1] / GRADS))
            ek = int((k + 1) * (mel_norm.shape[1] / GRADS))
            guidance[:, pk:ek] = gen_mel[:, pk:ek]
            mask[:, :ek] = 1
        gen_mel_denorm = denormalize_mel(gen_mel)
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        # Decode with the spectral sampler like every other pipeline; self.diffuser is configured for self.model.
        gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                                       model_kwargs={'aligned_conditioning': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_reconstruction_from_cheater_gen(self, audio, sample_rate=22050):
        """'cheater_gen' mode: generate a cheater latent from a reference, then decode latent -> mel -> waveform.

        The reference waveform is the true mel decoded by the same spectral decoder.
        """
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
        # 1. Generate the cheater latent using the input as a reference.
        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
        # Swap the two halves of the reference latent and put 160 zero frames between them (the zeros are not
        # actually used). This is a hack for giving tfdpc5 a bigger working context. Assumes 256 latent channels.
        cheater_padded = torch.cat([cheater[:, :, cheater.shape[-1] // 2:],
                                    torch.zeros(1, 256, 160, device=cheater.device),
                                    cheater[:, :, :cheater.shape[-1] // 2]], dim=-1)
        gen_cheater = sampler(self.model, cheater.shape, progress=True,
                              causal=self.causal, causal_slope=self.causal_slope,
                              model_kwargs={'conditioning_input': cheater_padded, 'cond_start': 80})
        # 2. Decode the cheater into a MEL (16 mel frames per latent frame).
        gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device),
                                                                 (1, 256, gen_cheater.shape[-1] * 16), progress=True,
                                                                 model_kwargs={'codes': gen_cheater.permute(0, 2, 1)})
        # 3. And then the MEL back into a waveform.
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_mel_denorm = denormalize_mel(gen_mel)
        gen_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape,
                                                          model_kwargs={'codes': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)

        real_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape,
                                                           model_kwargs={'codes': mel})
        real_wav = pixel_shuffle_1d(real_wav, 16)
        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_ar_prior(self, audio, sample_rate=22050):
        """'from_ar_prior' mode: condition cheater generation on the AR prior's latent for k-means codes of the input.

        Pipeline: input mel -> cheater latent -> k-means codes -> AR prior latent -> generated cheater latent ->
        mel -> waveform. The reference waveform is the true mel decoded by the same spectral decoder.
        """
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
        cheater_codes = self.kmeans_inj({'in': cheater})['out']
        ar_latent = self.local_modules['ar_prior'].to(audio.device)(cheater_codes, cheater, return_latent=True)
        # 1. Generate the cheater latent using the AR prior's latent as conditioning.
        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
        gen_cheater = sampler(self.model, cheater.shape, progress=True,
                              causal=self.causal, causal_slope=self.causal_slope,
                              model_kwargs={'codes': ar_latent})
        # 2. Decode the cheater into a MEL.
        gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device),
                                                                 (1, 256, gen_cheater.shape[-1] * 16), progress=True,
                                                                 model_kwargs={'codes': gen_cheater.permute(0, 2, 1)})
        gen_mel_denorm = denormalize_mel(gen_mel)
        # 3. Decode into waveform.
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape, model_kwargs={'codes': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)
        real_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape, model_kwargs={'codes': mel})
        real_wav = pixel_shuffle_1d(real_wav, 16)
        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def perform_fake_ar_reconstruction_from_cheater_gen(self, audio, sample_rate=22050):
        """Experimental: grow a cheater latent one frame at a time with guided sampling, then decode it in chunks.

        No 'diffusion_type' in __init__ selects this method. It relies on the cheater encoder/decoder modules and
        ``self.cheater_decoder_diffuser``, which only the 'cheater_gen' and 'from_ar_prior' modes create. The
        returned ``gen_mel`` is the mel of the last decoded chunk only.
        """
        assert self.ddim, "DDIM mode expected for reconstructing cheater gen. Do you like to waste resources??"
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)

        # 1. Generate the cheater latent using the input as a reference.
        def diffuse(i, ref):
            # Guide on the first i frames of ref; the rest are sampled.
            mask = torch.zeros_like(ref)
            mask[:, :, :i] = 1
            return self.diffuser.p_sample_loop_with_guidance(self.model, ref, mask,
                                                             model_kwargs={'conditioning_input': cheater})

        gen_cheater = torch.randn_like(cheater)
        for i in range(cheater.shape[-1]):
            gen_cheater = diffuse(i, gen_cheater)
            if i > 128:
                # abort early.
                gen_cheater = gen_cheater[:, :, :128]
                break
        # 2. Decode the cheater into a MEL. This operation and the next need to be chunked to make them feasible to
        # perform within GPU memory.
        chunks = torch.split(gen_cheater, 64, dim=-1)
        chunk_output_shape = (1, 16, audio.shape[-1] // (16 * len(chunks)))
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wavs = []
        for chunk in tqdm(chunks):
            gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device),
                                                                     (1, 256, chunk.shape[-1] * 16), progress=True,
                                                                     model_kwargs={'codes': chunk.permute(0, 2, 1)})
            # 3. And then the MEL back into a waveform.
            gen_mel_denorm = denormalize_mel(gen_mel)
            gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, chunk_output_shape,
                                                           model_kwargs={'codes': gen_mel_denorm})
            gen_wav = pixel_shuffle_1d(gen_wav, 16)
            gen_wavs.append(gen_wav)
        gen_wav = torch.cat(gen_wavs, dim=-1)
        """ How to do progressive, causal decoding of the TFD diffuser:
        MAX_CONTEXT = 64
        def diffuse(start, len, guidance):
            mask = torch.zeros_like(guidance)
            mask[:,:,:(len-start)] = 1
            return self.cheater_decoder_diffuser.p_sample_loop_with_guidance(self.local_modules['cheater_decoder'].diff.to(audio.device),
                                                                             guidance_input=guidance, mask=mask,
                                                                             model_kwargs={'codes': gen_cheater[:,:,start:start+MAX_CONTEXT].permute(0,2,1)})
        guidance_mel = torch.zeros((1,256,MAX_CONTEXT*16), device=mel.device)
        gen_mel = torch.zeros((1,256,0), device=mel.device)
        for i in tqdm(list(range(gen_cheater.shape[-1]))):
            start = max(0, i-MAX_CONTEXT-1)
            l = min(16*(MAX_CONTEXT-1), i*16)
            ngm = diffuse(start, l, guidance_mel)
            gen_mel = torch.cat([gen_mel, ngm[:,:,l:l+16]], dim=-1)
            if gen_mel.shape[-1] < guidance_mel.shape[-1]:
                guidance_mel[:,:,:gen_mel.shape[-1]] = gen_mel
            else:
                guidance_mel = gen_mel[:,:,-guidance_mel.shape[-1]:]
        chunks = torch.split(gen_mel, MAX_CONTEXT*16, dim=-1)
        gen_wavs = []
        for chunk_mel in tqdm(chunks):
            # 3. And then the MEL back into a spectrogram
            output_shape = (1,16,audio.shape[-1]//(16*len(chunks)))
            self.spec_decoder = self.spec_decoder.to(audio.device)
            gen_mel_denorm = denormalize_mel(chunk_mel)
            gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                                           model_kwargs={'codes': gen_mel_denorm})
            gen_wav = pixel_shuffle_1d(gen_wav, 16)
            gen_wavs.append(gen_wav)
        gen_wav = torch.cat(gen_wavs, dim=-1)
        """
        if audio.shape[-1] < 40 * 22050:
            # The real mel spans the whole clip, so decode it at full length (not the per-chunk length used above).
            full_output_shape = (1, 16, audio.shape[-1] // 16)
            real_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, full_output_shape,
                                                            model_kwargs={'codes': mel})
            real_wav = pixel_shuffle_1d(real_wav, 16)
        else:
            real_wav = audio  # TODO: chunk like above.
        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def project(self, sample, sample_rate):
        """Embed a waveform with the contrastive projector after resampling it to 22050 Hz."""
        sample = torchaudio.functional.resample(sample, sample_rate, 22050)
        mel = self.spec_fn({'in': sample})['out']
        projection = self.projector.project(mel)
        return projection.squeeze(0)  # Getting rid of the batch dimension means it's just [hidden_dim]

    def compute_frechet_distance(self, proj1, proj2):
        """Return the Frechet distance between two embedding sets as a float64 scalar tensor.

        Args:
            proj1, proj2: ``(num_clips, hidden_dim)`` tensors of projections (as stacked by perform_eval()).

        Returns 0 (still a float64 tensor) when ``calculate_frechet_distance`` raises, so a failed metric never aborts
        training and every rank hands the same dtype to ``distributed.all_reduce``.
        """
        # pytorch_fid works on numpy arrays, so the projections are moved to host memory first.
        proj1 = proj1.cpu().numpy()
        proj2 = proj2.cpu().numpy()
        mu1 = np.mean(proj1, axis=0)
        mu2 = np.mean(proj2, axis=0)
        sigma1 = np.cov(proj1, rowvar=False)
        sigma2 = np.cov(proj2, rowvar=False)
        try:
            distance = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
        except Exception as e:  # Best-effort metric: report 0 rather than crash, but don't trap KeyboardInterrupt.
            print(f"Unable to compute Frechet distance ({e!r}); reporting 0.")
            return torch.tensor(0.0, dtype=torch.float64)
        return torch.tensor(distance, dtype=torch.float64)

    def perform_eval(self):
        """Run the configured pipeline over the eval set and return ``{"frechet_distance": <scalar tensor>}``.

        Each rank handles every ``self.skip``-th file, offset by its rank, and writes its generated/real audio and
        mel images to disk. In distributed runs the per-rank distances are averaged with ``all_reduce``. Model mode,
        RNG state (CPU and the current CUDA device) and module placement are restored even if generation raises.

        Raises:
            RuntimeError: if no ``*.wav`` files were found under the eval path.
        """
        if not self.data:
            raise RuntimeError(f"No .wav files found in {self.real_path}; nothing to evaluate.")
        save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
        os.makedirs(save_path, exist_ok=True)

        self.projector = self.projector.to(self.dev)
        self.projector.eval()

        # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
        # torch.manual_seed() also reseeds CUDA, so the current CUDA device's state is saved as well.
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
        torch.manual_seed(5)
        self.model.eval()
        try:
            with torch.no_grad():
                gen_projections = []
                real_projections = []
                for i in tqdm(range(0, len(self.data), self.skip)):
                    path = self.data[(i + self.env['rank']) % len(self.data)]
                    audio = load_audio(path, 22050).to(self.dev)
                    if self.clip:
                        audio = audio[:, :100000]
                    sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)

                    # Store on CPU to avoid wasting GPU memory.
                    gen_projections.append(self.project(sample, sample_rate).cpu())
                    real_projections.append(self.project(ref, sample_rate).cpu())

                    torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
                    torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
                    # Mels are normalized to roughly [-1, 1]; map to [0, 1] for image export.
                    torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_gen_mel.png"))
                    torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_real_mel.png"))
                gen_projections = torch.stack(gen_projections, dim=0)
                real_projections = torch.stack(real_projections, dim=0)
                frechet_distance = torch.as_tensor(self.compute_frechet_distance(gen_projections, real_projections),
                                                   device=self.env['device'])

                if distributed.is_initialized() and distributed.get_world_size() > 1:
                    distributed.all_reduce(frechet_distance)
                    frechet_distance = frechet_distance / distributed.get_world_size()
        finally:
            self.model.train()
            torch.set_rng_state(rng_state)
            if cuda_rng_state is not None:
                torch.cuda.set_rng_state(cuda_rng_state)
            # Put modules used for evaluation back into CPU memory.
            for k, mod in self.local_modules.items():
                self.local_modules[k] = mod.cpu()
            self.spec_decoder = self.spec_decoder.cpu()

        return {"frechet_distance": frechet_distance}
if __name__ == '__main__':
    # Ad-hoc local run of the evaluator. Every path below is specific to the author's machine.
    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
                                       also_load_savepoint=False,
                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\93500_generator_ema.pth'
                                       ).cuda()
    opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                # 'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
                'diffusion_steps': 256,  # basis: 192
                'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
                'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
                # 'causal': True, 'causal_slope': 4,
                # 'partial_low': 128, 'partial_high': 192
                }
    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 7, 'device': 'cuda', 'opt': {}}
    # Named so it does not shadow the builtin eval().
    fid_evaluator = MusicDiffusionFid(diffusion, opt_eval, env)
    # perform_eval() reseeds the RNG on every call, so repeated runs show how stable the metric is.
    fds = []
    for _ in range(2):
        res = fid_evaluator.perform_eval()
        print(res)
        fds.append(res['frechet_distance'])
    print(fds)