2022-04-20 06:28:03 +00:00
|
|
|
import os
|
|
|
|
import os.path as osp
|
|
|
|
from glob import glob
|
2022-05-16 03:50:54 +00:00
|
|
|
from random import shuffle
|
2022-05-23 16:37:15 +00:00
|
|
|
from time import time
|
2022-04-20 06:28:03 +00:00
|
|
|
|
2022-05-09 00:49:39 +00:00
|
|
|
import numpy as np
|
2022-04-20 06:28:03 +00:00
|
|
|
import torch
|
|
|
|
import torchaudio
|
2022-05-02 05:04:56 +00:00
|
|
|
import torchvision
|
2022-04-20 06:28:03 +00:00
|
|
|
from pytorch_fid.fid_score import calculate_frechet_distance
|
|
|
|
from torch import distributed
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
import trainer.eval.evaluator as evaluator
|
|
|
|
from data.audio.unsupervised_audio_dataset import load_audio
|
2022-05-22 11:23:54 +00:00
|
|
|
from models.audio.mel2vec import ContrastiveTrainingWrapper
|
2022-05-06 20:33:44 +00:00
|
|
|
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
|
2022-05-07 03:56:49 +00:00
|
|
|
from models.clip.contrastive_audio import ContrastiveAudio
|
2022-04-20 06:28:03 +00:00
|
|
|
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
|
|
|
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
2022-05-06 20:33:44 +00:00
|
|
|
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
|
|
|
|
normalize_mel
|
2022-05-22 11:23:54 +00:00
|
|
|
from utils.music_utils import get_music_codegen, get_mel2wav_model
|
2022-05-09 00:49:39 +00:00
|
|
|
from utils.util import opt_get, load_model_from_config
|
2022-04-20 06:28:03 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MusicDiffusionFid(evaluator.Evaluator):
    """
    Evaluator that generates music with a diffusion model and scores it with a Frechet distance.

    For every .wav file in opt_eval['path'], the pipeline selected by opt_eval['diffusion_type']
    ('spec_decode', 'from_codes' or 'from_codes_quant') generates a waveform conditioned on the real
    clip. Generated and real clips are embedded with a ContrastiveAudio projector and the Frechet
    distance between the two sets of embeddings is returned. Generated/real audio and mel images are
    also written under <base_path>/../audio_eval/<step>/.

    Other opt_eval keys read here: 'diffusion_steps', 'diffusion_schedule', 'conditioning_free',
    'conditioning_free_k' and 'projector_path' (path to the projector weights).
    """

    # Rate (Hz) at which eval audio is loaded and fed to the models.
    NATIVE_SAMPLE_RATE = 22050
    # Default location of the projector weights used to embed audio for the distance computation.
    DEFAULT_PROJECTOR_PATH = '../experiments/music_eval_projector.pth'

    def __init__(self, model, opt_eval, env):
        super().__init__(model, opt_eval, env, uses_all_ddp=True)
        self.real_path = opt_eval['path']
        self.data = self.load_data(self.real_path)
        if distributed.is_initialized() and distributed.get_world_size() > 1:
            self.skip = distributed.get_world_size()  # One batch element per GPU.
        else:
            self.skip = 1
        diffusion_steps = opt_get(opt_eval, ['diffusion_steps'], 50)
        diffusion_schedule = opt_get(env['opt'], ['steps', 'generator', 'injectors', 'diffusion', 'beta_schedule', 'schedule_name'], None)
        if diffusion_schedule is None:
            print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
            diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'linear')
        conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
        conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
        self.diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [diffusion_steps]), model_mean_type='epsilon',
                                        model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(diffusion_schedule, 4000),
                                        conditioning_free=conditioning_free_diffusion_enabled, conditioning_free_k=conditioning_free_k)
        self.dev = self.env['device']
        mode = opt_get(opt_eval, ['diffusion_type'], 'tts')

        self.spec_decoder = get_mel2wav_model()
        self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256)
        projector_path = opt_get(opt_eval, ['projector_path'], self.DEFAULT_PROJECTOR_PATH)
        self.projector.load_state_dict(torch.load(projector_path, map_location=torch.device('cpu')))
        # Modules only needed during eval; perform_eval() moves them back to CPU when it finishes.
        self.local_modules = {'spec_decoder': self.spec_decoder, 'projector': self.projector}

        if mode == 'spec_decode':
            self.diffusion_fn = self.perform_diffusion_spec_decode
        elif mode == 'from_codes':
            self.diffusion_fn = self.perform_diffusion_from_codes
            self.local_modules['codegen'] = get_music_codegen()
        elif mode == 'from_codes_quant':
            self.diffusion_fn = self.perform_diffusion_from_codes_quant
        else:
            # Fail fast: without this, diffusion_fn would be unset and perform_eval() would crash much later.
            raise ValueError(f"Unsupported diffusion_type '{mode}'; expected one of "
                             f"'spec_decode', 'from_codes', 'from_codes_quant'.")
        self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
                                                    'normalize': True, 'in': 'in', 'out': 'out'}, {})

    def load_data(self, path):
        """Return the .wav files directly inside `path`, sorted.

        glob() returns files in arbitrary order; sorting keeps the rank-based sharding in perform_eval()
        consistent across processes and the eval set ordering reproducible across runs.
        """
        return sorted(glob(osp.join(path, '*.wav')))

    def _to_output_rate(self, wav, sample_rate):
        """Resample `wav` from NATIVE_SAMPLE_RATE to `sample_rate`; returns `wav` itself when the rates match."""
        if sample_rate == self.NATIVE_SAMPLE_RATE:
            return wav
        return torchaudio.functional.resample(wav, self.NATIVE_SAMPLE_RATE, sample_rate)

    def _decode_mel_to_wav(self, gen_mel, audio):
        """Decode a normalized generated mel into a waveform with the mel->wav diffusion decoder.

        `audio` is the batched real clip; only its length and device are used, to size and place the output.
        """
        gen_mel_denorm = denormalize_mel(gen_mel)
        # The decoder works on 16 interleaved sub-signals; pixel_shuffle_1d folds them back into one waveform.
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                              model_kwargs={'aligned_conditioning': gen_mel_denorm})
        return pixel_shuffle_1d(gen_wav, 16)

    def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
        """Generate a waveform directly from the real clip's mel with self.model.

        Args:
            audio: real clip at NATIVE_SAMPLE_RATE (indexed as [:, :T] by perform_eval, so presumably (channels, T)).
            sample_rate: rate of the returned waveforms.
        Returns:
            (generated wav, real wav, generated normalized mel, real normalized mel, sample_rate)
        """
        real_resampled = self._to_output_rate(audio, sample_rate)
        audio = audio.unsqueeze(0)
        output_shape = (1, 16, audio.shape[-1] // 16)
        mel = self.spec_fn({'in': audio})['out']
        gen = self.diffuser.p_sample_loop(self.model, output_shape,
                                          model_kwargs={'aligned_conditioning': mel})
        gen = pixel_shuffle_1d(gen, 16)
        # Compute the generated mel at the native rate, before any output resampling.
        gen_mel = normalize_mel(self.spec_fn({'in': gen})['out'])
        return self._to_output_rate(gen, sample_rate), real_resampled, gen_mel, normalize_mel(mel), sample_rate

    def perform_diffusion_from_codes(self, audio, sample_rate=22050):
        """Generate a mel from music codes of the real clip, then decode it to a waveform.

        Args/returns are the same as perform_diffusion_spec_decode().
        """
        real_resampled = self._to_output_rate(audio, sample_rate)
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        codegen = self.local_modules['codegen'].to(mel.device)
        codes = codegen.get_codes(mel, project=True)
        mel_norm = normalize_mel(mel)
        # All-zero conditioning clip; 390 frames presumably matches the training conditioning length (TODO confirm).
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                              model_kwargs={'codes': codes, 'conditioning_input': torch.zeros_like(mel_norm[:, :, :390])})
        gen_wav = self._decode_mel_to_wav(gen_mel, audio)
        return self._to_output_rate(gen_wav, sample_rate), real_resampled, gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
        """Generate a mel with a model that quantizes the real mel ('truth_mel') itself, then decode it to a waveform.

        Args/returns are the same as perform_diffusion_spec_decode().
        """
        real_resampled = self._to_output_rate(audio, sample_rate)
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        # All-zero conditioning clip; 390 frames presumably matches the training conditioning length (TODO confirm).
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                              model_kwargs={'truth_mel': mel,
                                                            'conditioning_input': torch.zeros_like(mel_norm[:, :, :390]),
                                                            'disable_diversity': True})
        gen_wav = self._decode_mel_to_wav(gen_mel, audio)
        return self._to_output_rate(gen_wav, sample_rate), real_resampled, gen_mel, mel_norm, sample_rate

    def project(self, sample, sample_rate):
        """Embed a waveform recorded at `sample_rate` with the projector (after resampling to NATIVE_SAMPLE_RATE)."""
        sample = torchaudio.functional.resample(sample, sample_rate, self.NATIVE_SAMPLE_RATE)
        mel = self.spec_fn({'in': sample})['out']
        projection = self.projector.project(mel)
        return projection.squeeze(0)  # Getting rid of the batch dimension means it's just [hidden_dim]

    def compute_frechet_distance(self, proj1, proj2):
        """Frechet distance between the Gaussians fitted to two sets of projections (rows = samples).

        Returns a float64 tensor. If the statistics are degenerate (for example, too few samples for a
        covariance, or a complex matrix square root), prints the reason and returns 0 so training is not aborted.
        """
        # pytorch_fid's implementation works on numpy arrays, so the projections are converted here.
        proj1 = proj1.cpu().numpy()
        proj2 = proj2.cpu().numpy()
        mu1 = np.mean(proj1, axis=0)
        mu2 = np.mean(proj2, axis=0)
        sigma1 = np.cov(proj1, rowvar=False)
        sigma2 = np.cov(proj2, rowvar=False)
        try:
            return torch.tensor(calculate_frechet_distance(mu1, sigma1, mu2, sigma2), dtype=torch.float64)
        except (ValueError, np.linalg.LinAlgError) as e:
            print(f"Unable to compute Frechet distance; reporting 0. Reason: {e}")
            # Same dtype as the success path so a distributed all_reduce sees matching tensors on every rank.
            return torch.tensor(0.0, dtype=torch.float64)

    def perform_eval(self):
        """Generate audio for this rank's share of the eval set and return {'frechet_distance': tensor}.

        When running distributed, the per-rank distances are averaged across ranks. Whether or not
        generation succeeds, this method then:
        - puts the model back in train mode,
        - restores the CPU and CUDA RNG states,
        - moves eval-only modules back to CPU.
        """
        if not self.data:
            raise ValueError(f"No .wav files found in '{self.real_path}'; cannot compute a Frechet distance.")
        rank = self.env['rank']
        save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
        os.makedirs(save_path, exist_ok=True)

        self.projector = self.projector.to(self.dev)
        self.projector.eval()

        # Fix the random state as much as possible. torch.manual_seed also reseeds the CUDA generators, so both
        # the CPU and CUDA states are captured here and restored before returning.
        rng_state = torch.get_rng_state()
        cuda_rng_states = torch.cuda.get_rng_state_all() if torch.cuda.is_initialized() else None
        torch.manual_seed(5)
        self.model.eval()
        try:
            with torch.no_grad():
                gen_projections = []
                real_projections = []
                for i in tqdm(range(0, len(self.data), self.skip)):
                    # Offset by rank so each process handles a different file.
                    path = self.data[(i + rank) % len(self.data)]
                    audio = load_audio(path, self.NATIVE_SAMPLE_RATE).to(self.dev)
                    audio = audio[:, :self.NATIVE_SAMPLE_RATE * 10]  # Keep at most the first 10 seconds.
                    sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)

                    gen_projections.append(self.project(sample, sample_rate).cpu())  # Store on CPU to avoid wasting GPU memory.
                    real_projections.append(self.project(ref, sample_rate).cpu())

                    torchaudio.save(os.path.join(save_path, f"{rank}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
                    torchaudio.save(os.path.join(save_path, f"{rank}_{i}_real.wav"), ref.cpu(), sample_rate)
                    # Mels are mapped from [-1, 1] to [0, 1] for saving as images.
                    torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{rank}_{i}_gen_mel.png"))
                    torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{rank}_{i}_real_mel.png"))

                gen_projections = torch.stack(gen_projections, dim=0)
                real_projections = torch.stack(real_projections, dim=0)
                frechet_distance = torch.as_tensor(self.compute_frechet_distance(gen_projections, real_projections),
                                                   dtype=torch.float64, device=self.dev)

                if distributed.is_initialized() and distributed.get_world_size() > 1:
                    distributed.all_reduce(frechet_distance)
                    frechet_distance = frechet_distance / distributed.get_world_size()
        finally:
            self.model.train()
            torch.set_rng_state(rng_state)
            if cuda_rng_states is not None:
                torch.cuda.set_rng_state_all(cuda_rng_states)

            # Put modules used for evaluation back into CPU memory.
            for k, mod in self.local_modules.items():
                self.local_modules[k] = mod.cpu()
            self.spec_decoder = self.spec_decoder.cpu()

        return {"frechet_distance": frechet_distance}
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Ad-hoc local harness: evaluates one checkpoint on a local eval set. All paths are developer-machine specific.
    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\train_music_diffusion_tfd5_quant.yml', 'generator',
                                       also_load_savepoint=False,
                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\models\\40500_generator_ema.pth'
                                       ).cuda()
    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
                'conditioning_free': True, 'conditioning_free_k': 1,
                'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 560, 'device': 'cuda', 'opt': {}}
    # Named to avoid shadowing the builtin `eval` (and the imported `evaluator` module).
    fid_evaluator = MusicDiffusionFid(diffusion, opt_eval, env)
    print(fid_evaluator.perform_eval())
|