import functools
import os
import os.path as osp
from glob import glob
from random import shuffle
from time import time

import numpy as np
import torch
import torchaudio
import torchvision
from pytorch_fid.fid_score import calculate_frechet_distance
from torch import distributed
from tqdm import tqdm

import trainer.eval.evaluator as evaluator
from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.mel2vec import ContrastiveTrainingWrapper
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
from models.clip.contrastive_audio import ContrastiveAudio
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
    normalize_mel
from utils.music_utils import get_music_codegen, get_mel2wav_model
from utils.util import opt_get, load_model_from_config


class MusicDiffusionFid(evaluator.Evaluator):
    """
    Evaluator that generates audio with a music diffusion model and scores it.

    For every reference clip in ``opt_eval['path']`` the pipeline selected by ``opt_eval['diffusion_type']``
    produces a generated waveform. Generated and real waveforms are embedded with a ``ContrastiveAudio``
    projector, and the Frechet distance between the two embedding sets is reported. Generated/real audio
    and mel images are also written to disk for inspection.
    """

    # Rate clips are loaded at in `perform_eval` and the rate `project` resamples to before computing mels.
    _NATIVE_SAMPLE_RATE = 22050
    # Default projector weights; can be overridden with opt_eval['projector_path'].
    _DEFAULT_PROJECTOR_PATH = '../experiments/music_eval_projector.pth'
    # Supported values of opt_eval['diffusion_type'].
    _MODES = ('spec_decode', 'from_codes', 'from_codes_quant', 'partial_from_codes_quant',
              'from_codes_quant_gradual_decode')

    def __init__(self, model, opt_eval, env):
        """
        Build the diffuser, helper models and the diffusion pipeline selected by the eval options.

        Raises:
            ValueError: if opt_eval['diffusion_type'] (default 'tts') is not one of `_MODES`.
        """
        super().__init__(model, opt_eval, env, uses_all_ddp=True)
        self.real_path = opt_eval['path']
        self.data = self.load_data(self.real_path)
        if distributed.is_initialized() and distributed.get_world_size() > 1:
            self.skip = distributed.get_world_size()  # One batch element per GPU.
        else:
            self.skip = 1

        diffusion_steps = opt_get(opt_eval, ['diffusion_steps'], 50)
        diffusion_schedule = opt_get(env['opt'], ['steps', 'generator', 'injectors', 'diffusion', 'beta_schedule',
                                                  'schedule_name'], None)
        if diffusion_schedule is None:
            print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
            diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'linear')
        conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
        conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
        # 4000 base timesteps, respaced down to `diffusion_steps` for sampling.
        self.diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [diffusion_steps]),
                                        model_mean_type='epsilon', model_var_type='learned_range', loss_type='mse',
                                        betas=get_named_beta_schedule(diffusion_schedule, 4000),
                                        conditioning_free=conditioning_free_diffusion_enabled,
                                        conditioning_free_k=conditioning_free_k)
        self.dev = self.env['device']
        mode = opt_get(opt_eval, ['diffusion_type'], 'tts')

        self.spec_decoder = get_mel2wav_model()
        self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8,
                                          mel_channels=256)
        projector_path = opt_get(opt_eval, ['projector_path'], self._DEFAULT_PROJECTOR_PATH)
        self.projector.load_state_dict(torch.load(projector_path, map_location=torch.device('cpu')))
        # Modules moved back to CPU at the end of every eval.
        self.local_modules = {'spec_decoder': self.spec_decoder, 'projector': self.projector}

        if mode == 'spec_decode':
            self.diffusion_fn = self.perform_diffusion_spec_decode
        elif mode == 'from_codes':
            self.diffusion_fn = self.perform_diffusion_from_codes
            self.local_modules['codegen'] = get_music_codegen()
        elif mode == 'from_codes_quant':
            self.diffusion_fn = self.perform_diffusion_from_codes_quant
        elif mode == 'partial_from_codes_quant':
            self.diffusion_fn = functools.partial(self.perform_partial_diffusion_from_codes_quant,
                                                  partial_low=opt_eval['partial_low'],
                                                  partial_high=opt_eval['partial_high'])
        elif mode == 'from_codes_quant_gradual_decode':
            self.diffusion_fn = self.perform_diffusion_from_codes_quant_gradual_decode
        else:
            # Previously an unknown mode left `diffusion_fn` unset and only failed later in perform_eval.
            raise ValueError(f"Unsupported diffusion_type {mode!r}; expected one of {self._MODES}.")

        self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
                                                    'normalize': True, 'in': 'in', 'out': 'out'}, {})

    def load_data(self, path):
        """Return the ``.wav`` files directly inside `path`, sorted so every rank sees the same order."""
        # Ranks pick files by index into this list, so the order must be deterministic.
        return sorted(glob(osp.join(path, '*.wav')))

    def _prepare_reference(self, audio, sample_rate):
        """
        Return ``(real_resampled, batched_audio)`` for a clip loaded at `_NATIVE_SAMPLE_RATE`.

        `real_resampled` is `audio` resampled to `sample_rate` (unchanged when the rates match), keeping the
        same rank as `audio` so it can be saved with torchaudio.save. `batched_audio` is `audio` with a leading
        batch dimension added.
        """
        # The original code compared `sample_rate != sample_rate` (always False), so this branch never ran.
        # NOTE(review): the generated audio is still produced at the native rate; callers only use the default.
        if sample_rate != self._NATIVE_SAMPLE_RATE:
            real_resampled = torchaudio.functional.resample(audio, self._NATIVE_SAMPLE_RATE, sample_rate)
        else:
            real_resampled = audio
        return real_resampled, audio.unsqueeze(0)

    @staticmethod
    def _null_conditioning(mel_norm):
        """All-zeros conditioning input matching `mel_norm` truncated to 390 entries along its last axis."""
        # TODO(review): confirm why 390 is the expected conditioning length.
        return torch.zeros_like(mel_norm[:, :, :390])

    def _decode_mel_to_wav(self, gen_mel, audio):
        """
        Turn a normalized mel into a waveform with the spec decoder.

        The waveform is sampled as 16 interleaved channels of length ``audio.shape[-1] // 16`` and
        pixel-shuffled back into one channel, matching the length of `audio` (up to the //16 rounding).
        """
        gen_mel_denorm = denormalize_mel(gen_mel)
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                              model_kwargs={'aligned_conditioning': gen_mel_denorm})
        return pixel_shuffle_1d(gen_wav, 16)

    def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
        """
        Generate a waveform directly with `self.model`, conditioned on the mel of `audio`.

        Returns:
            (generated wav, reference wav, normalized mel of the generated wav, normalized reference mel,
            sample_rate)
        """
        real_resampled, audio = self._prepare_reference(audio, sample_rate)
        output_shape = (1, 16, audio.shape[-1] // 16)
        mel = self.spec_fn({'in': audio})['out']
        gen = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'aligned_conditioning': mel})
        gen = pixel_shuffle_1d(gen, 16)
        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate

    def perform_diffusion_from_codes(self, audio, sample_rate=22050):
        """
        Generate a mel with `self.model` from codegen codes of `audio`, then vocode it with the spec decoder.

        Returns:
            (generated wav, reference wav, generated mel, normalized reference mel, sample_rate)
        """
        real_resampled, audio = self._prepare_reference(audio, sample_rate)
        mel = self.spec_fn({'in': audio})['out']
        codegen = self.local_modules['codegen'].to(mel.device)
        codes = codegen.get_codes(mel, project=True)
        mel_norm = normalize_mel(mel)
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                              model_kwargs={'codes': codes,
                                                            'conditioning_input': self._null_conditioning(mel_norm)})
        gen_wav = self._decode_mel_to_wav(gen_mel, audio)
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
        """
        Generate a mel with `self.model` given the true mel (as `truth_mel`), then vocode it.

        Returns:
            (generated wav, reference wav, generated mel, normalized reference mel, sample_rate)
        """
        real_resampled, audio = self._prepare_reference(audio, sample_rate)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape, clip_denoised=False,
                                              model_kwargs={'truth_mel': mel,
                                                            'conditioning_input': self._null_conditioning(mel_norm),
                                                            'disable_diversity': True})
        gen_wav = self._decode_mel_to_wav(gen_mel, audio)
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_partial_diffusion_from_codes_quant(self, audio, sample_rate=22050, partial_low=0, partial_high=256):
        """
        Regenerate rows ``partial_low:partial_high`` (axis 1) of the mel with guided sampling, then vocode it.

        The guidance is the real normalized mel; the mask is 1 everywhere except the selected rows.

        Returns:
            (generated wav, reference wav, generated mel, normalized reference mel, sample_rate)
        """
        real_resampled, audio = self._prepare_reference(audio, sample_rate)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        mask = torch.ones_like(mel_norm)
        mask[:, partial_low:partial_high] = 0  # This is the channel region that the model will predict.
        gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model, guidance_input=mel_norm, mask=mask,
                                                            model_kwargs={'truth_mel': mel,
                                                                          'conditioning_input': self._null_conditioning(mel_norm),
                                                                          'disable_diversity': True})
        gen_wav = self._decode_mel_to_wav(gen_mel, audio)
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_quant_gradual_decode(self, audio, sample_rate=22050):
        """
        Generate the mel in 4 successive bands along axis 1, each pass guided by the bands produced so far.

        Returns:
            (generated wav, reference wav, generated mel from the final pass, normalized reference mel,
            sample_rate)
        """
        real_resampled, audio = self._prepare_reference(audio, sample_rate)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_mel(mel)
        guidance = torch.zeros_like(mel_norm)
        mask = torch.zeros_like(mel_norm)
        GRADS = 4
        for k in range(GRADS):
            gen_mel = self.diffuser.p_sample_loop_with_guidance(self.model, guidance_input=guidance, mask=mask,
                                                                model_kwargs={'truth_mel': mel,
                                                                              'conditioning_input': self._null_conditioning(mel_norm),
                                                                              'disable_diversity': True})
            pk = int(k * (mel_norm.shape[1] / GRADS))
            ek = int((k + 1) * (mel_norm.shape[1] / GRADS))
            # Freeze the band just generated so later passes are guided by it.
            guidance[:, pk:ek] = gen_mel[:, pk:ek]
            mask[:, :ek] = 1
        gen_wav = self._decode_mel_to_wav(gen_mel, audio)
        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate

    def project(self, sample, sample_rate):
        """Embed a waveform with the projector after resampling it from `sample_rate` to the native rate."""
        sample = torchaudio.functional.resample(sample, sample_rate, self._NATIVE_SAMPLE_RATE)
        mel = self.spec_fn({'in': sample})['out']
        projection = self.projector.project(mel)
        return projection.squeeze(0)  # Getting rid of the batch dimension means it's just [hidden_dim]

    def compute_frechet_distance(self, proj1, proj2):
        """
        Frechet distance between two sets of projections (one row per sample).

        Returns a 0-dim float64 tensor. This is best-effort: if pytorch_fid fails, the error is logged and
        0.0 is returned, still as float64, so distributed reductions see a consistent dtype on every rank.
        """
        # pytorch_fid works on numpy arrays, so the projections have to leave torch here.
        proj1 = proj1.cpu().numpy()
        proj2 = proj2.cpu().numpy()
        mu1 = np.mean(proj1, axis=0)
        mu2 = np.mean(proj2, axis=0)
        sigma1 = np.cov(proj1, rowvar=False)
        sigma2 = np.cov(proj2, rowvar=False)
        try:
            distance = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
        except Exception as e:  # Deliberately best-effort, but no longer swallows KeyboardInterrupt/SystemExit.
            print(f"Failed to compute Frechet distance ({e!r}); reporting 0.")
            distance = 0.0
        return torch.tensor(float(distance), dtype=torch.float64)

    def perform_eval(self):
        """
        Generate audio for each reference clip assigned to this rank and report the Frechet distance.

        Outputs (generated/real wavs and mel images) are written under
        ``<base_path>/../audio_eval/<step>``. Model train mode, RNG state and the CPU placement of helper
        modules are restored even if generation fails.

        Returns:
            {"frechet_distance": 0-dim float64 tensor}, averaged over ranks when running distributed.

        Raises:
            ValueError: if no ``.wav`` files were found in the eval path.
        """
        if not self.data:
            raise ValueError(f"No .wav files found in {self.real_path!r}; nothing to evaluate.")

        save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
        os.makedirs(save_path, exist_ok=True)

        self.projector = self.projector.to(self.dev)
        self.projector.eval()

        # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
        # torch.manual_seed also reseeds CUDA, so the CUDA generator of the eval device is saved as well.
        cuda_rng_device = self.dev if torch.cuda.is_available() and torch.device(self.dev).type == 'cuda' else None
        cpu_rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state(cuda_rng_device) if cuda_rng_device is not None else None
        torch.manual_seed(5)

        self.model.eval()
        try:
            with torch.no_grad():
                gen_projections = []
                real_projections = []
                for i in tqdm(range(0, len(self.data), self.skip)):
                    # Each rank offsets into the shared file list by its rank.
                    path = self.data[(i + self.env['rank']) % len(self.data)]
                    audio = load_audio(path, self._NATIVE_SAMPLE_RATE).to(self.dev)
                    audio = audio[:, :self._NATIVE_SAMPLE_RATE * 10]  # First 10 s at the native rate.
                    sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)

                    # Store on CPU to avoid wasting GPU memory.
                    gen_projections.append(self.project(sample, sample_rate).cpu())
                    real_projections.append(self.project(ref, sample_rate).cpu())

                    prefix = osp.join(save_path, f"{self.env['rank']}_{i}")
                    torchaudio.save(f"{prefix}_gen.wav", sample.squeeze(0).cpu(), sample_rate)
                    torchaudio.save(f"{prefix}_real.wav", ref.cpu(), sample_rate)
                    # Mels are shifted by +1 and halved before being saved as images.
                    torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, f"{prefix}_gen_mel.png")
                    torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, f"{prefix}_real_mel.png")

                gen_projections = torch.stack(gen_projections, dim=0)
                real_projections = torch.stack(real_projections, dim=0)
                distance = self.compute_frechet_distance(gen_projections, real_projections)
                # Explicit float64 so every rank hands all_reduce the same dtype, even where FID failed.
                frechet_distance = torch.tensor(float(distance), dtype=torch.float64, device=self.dev)
                if distributed.is_initialized() and distributed.get_world_size() > 1:
                    distributed.all_reduce(frechet_distance)
                    frechet_distance = frechet_distance / distributed.get_world_size()
        finally:
            self.model.train()
            torch.set_rng_state(cpu_rng_state)
            if cuda_rng_state is not None:
                torch.cuda.set_rng_state(cuda_rng_state, cuda_rng_device)
            # Put modules used for evaluation back into CPU memory.
            for k, mod in self.local_modules.items():
                self.local_modules[k] = mod.cpu()
            self.spec_decoder = self.spec_decoder.cpu()

        return {"frechet_distance": frechet_distance}


if __name__ == '__main__':
    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_ar_prior.yml', 'generator',
                                       also_load_savepoint=False,
                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_ar_prior\\models\\22000_generator_ema.pth'
                                       ).cuda()
    opt_eval = {
        # Alternate eval set: 'Y:\\split\\yt-music-eval'
        'path': 'E:\\music_eval',
        'diffusion_steps': 100,
        'conditioning_free': False,
        'conditioning_free_k': 1,
        'diffusion_schedule': 'linear',
        'diffusion_type': 'partial_from_codes_quant',
        'partial_low': 128,
        'partial_high': 192,
    }
    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 504, 'device': 'cuda', 'opt': {}}
    fid_evaluator = MusicDiffusionFid(diffusion, opt_eval, env)
    print(fid_evaluator.perform_eval())