# forked from mrq/DL-Art-School
#
# - use a gated activation layer for both attention & convs
# - add a relativistic learned position bias. I believe this is similar to the T5
#   position encodings, but it is simpler and learned.
# - get rid of prepending to the attention matrix - this doesn't really work that
#   well. The model eventually learns to attend one of its heads to these blocks,
#   but why not just concat if it is doing that?
import os
import os.path as osp
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchvision
from pytorch_fid.fid_score import calculate_frechet_distance
from torch import distributed
from tqdm import tqdm

import trainer.eval.evaluator as evaluator
from data.audio.unsupervised_audio_dataset import load_audio
from models.clip.contrastive_audio import ContrastiveAudio
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
    KmeansQuantizerInjector, normalize_torch_mel
from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
    get_mel2wav_v3_model, get_ar_prior
from utils.util import opt_get, load_model_from_config
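

# The fork notes at the top of this file mention a gated activation layer and a
# learned relative position bias; neither is implemented in this file. The two
# sketches below are illustrative assumptions about what such layers typically
# look like (GLU-style gating; an embedding of clipped relative offsets), not the
# actual modules used by the model under evaluation.
from torch import nn


class _GatedActivationSketch(nn.Module):
    """Gated activation: y = tanh(W_v x) * sigmoid(W_g x) from one projection."""
    def __init__(self, channels):
        super().__init__()
        self.proj = nn.Conv1d(channels, channels * 2, kernel_size=1)

    def forward(self, x):  # x: (batch, channels, time)
        value, gate = self.proj(x).chunk(2, dim=1)
        return torch.tanh(value) * torch.sigmoid(gate)


class _LearnedRelativePositionBiasSketch(nn.Module):
    """Learned bias added to attention logits, indexed by clipped relative offset."""
    def __init__(self, heads, max_distance=128):
        super().__init__()
        self.max_distance = max_distance
        self.bias = nn.Embedding(2 * max_distance + 1, heads)

    def forward(self, qlen, klen):
        rel = torch.arange(klen)[None, :] - torch.arange(qlen)[:, None]
        rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
        return self.bias(rel).permute(2, 0, 1)  # (heads, qlen, klen)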


class MusicDiffusionFid(evaluator.Evaluator):
    """
    Evaluator that generates samples from a music diffusion model and measures the
    Frechet distance between projections of the generated and the real audio.
    """
    def __init__(self, model, opt_eval, env):
        super().__init__(model, opt_eval, env, uses_all_ddp=True)
        self.real_path = opt_eval['path']
        self.data = self.load_data(self.real_path)
        self.clip = opt_get(opt_eval, ['clip_audio'], True)  # Recommend setting true for more efficient eval passes.
        self.ddim = opt_get(opt_eval, ['use_ddim'], False)
        self.causal = opt_get(opt_eval, ['causal'], False)
        self.causal_slope = opt_get(opt_eval, ['causal_slope'], 1)
        if distributed.is_initialized() and distributed.get_world_size() > 1:
            self.skip = distributed.get_world_size()  # One batch element per GPU.
        else:
            self.skip = 1
        diffusion_steps = opt_get(opt_eval, ['diffusion_steps'], 50)
        diffusion_schedule = opt_get(env['opt'], ['steps', 'generator', 'injectors', 'diffusion', 'beta_schedule', 'schedule_name'], None)
        if diffusion_schedule is None:
            print("Unable to infer diffusion schedule from master options. Getting it from eval (or guessing).")
            diffusion_schedule = opt_get(opt_eval, ['diffusion_schedule'], 'linear')
        conditioning_free_diffusion_enabled = opt_get(opt_eval, ['conditioning_free'], False)
        conditioning_free_k = opt_get(opt_eval, ['conditioning_free_k'], 1)
        self.diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [diffusion_steps]), model_mean_type='epsilon',
                                        model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule(diffusion_schedule, 4000),
                                        conditioning_free=conditioning_free_diffusion_enabled, conditioning_free_k=conditioning_free_k)
        self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16 if self.ddim else 100]), model_mean_type='epsilon',
                                                 model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                 conditioning_free=False, conditioning_free_k=1)
        self.dev = self.env['device']
        mode = opt_get(opt_eval, ['diffusion_type'], 'spec_decode')

        self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256)
        self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth', map_location=torch.device('cpu')))

        self.local_modules = {'projector': self.projector}
        if mode == 'spec_decode':
            self.diffusion_fn = self.perform_diffusion_spec_decode
            self.squeeze_ratio = opt_eval['squeeze_ratio']
        elif mode == 'from_codes':
            self.diffusion_fn = self.perform_diffusion_from_codes
            self.local_modules['codegen'] = get_music_codegen()
        elif mode == 'from_codes_quant':
            self.diffusion_fn = self.perform_diffusion_from_codes_quant
        elif mode == 'cheater_gen':
            self.diffusion_fn = self.perform_reconstruction_from_cheater_gen
            self.local_modules['cheater_encoder'] = get_cheater_encoder()
            self.local_modules['cheater_decoder'] = get_cheater_decoder()
            self.cheater_decoder_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon',
                                                            model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                            conditioning_free=True, conditioning_free_k=1)
            self.spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [16]), model_mean_type='epsilon',
                                                     model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                     conditioning_free=False, conditioning_free_k=1)
            self.spec_decoder = get_mel2wav_v3_model()  # The only reason the other functions don't use v3 is because earlier models were trained with v1 and I want to keep metrics consistent.
            self.local_modules['spec_decoder'] = self.spec_decoder
        elif mode == 'from_ar_prior':
            self.diffusion_fn = self.perform_diffusion_from_codes_ar_prior
            self.local_modules['cheater_encoder'] = get_cheater_encoder()
            self.local_modules['cheater_decoder'] = get_cheater_decoder()
            self.cheater_decoder_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon',
                                                            model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
                                                            conditioning_free=True, conditioning_free_k=1)
            self.kmeans_inj = KmeansQuantizerInjector({'centroids': '../experiments/music_k_means_centroids.pth', 'in': 'in', 'out': 'out'}, {})
            self.local_modules['ar_prior'] = get_ar_prior()
            self.spec_decoder = get_mel2wav_v3_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        elif mode == 'chained_sr':
            self.diffusion_fn = self.perform_chained_sr
            self.spec_decoder = get_mel2wav_v3_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        if not hasattr(self, 'spec_decoder'):
            self.spec_decoder = get_mel2wav_model()
            self.local_modules['spec_decoder'] = self.spec_decoder
        self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
                                                    'normalize': True, 'in': 'in', 'out': 'out'}, {})
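
    # For reference, the opt_eval keys read above: 'path' (required), 'squeeze_ratio'
    # (required for the 'spec_decode' mode), 'clip_audio', 'use_ddim', 'causal',
    # 'causal_slope', 'diffusion_steps', 'diffusion_schedule', 'conditioning_free',
    # 'conditioning_free_k' and 'diffusion_type'. See the __main__ block at the bottom
    # of this file for a working example configuration.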

    def load_data(self, path):
        return list(glob(f'{path}/*.wav'))

    def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
        real_resampled = audio
        audio = audio.unsqueeze(0)
        output_shape = (1, self.squeeze_ratio, audio.shape[-1] // self.squeeze_ratio)
        mel = self.spec_fn({'in': audio})['out']
        gen = self.diffuser.p_sample_loop(self.model, output_shape,
                                          model_kwargs={'codes': mel})
        gen = pixel_shuffle_1d(gen, self.squeeze_ratio)

        return gen, real_resampled, normalize_torch_mel(self.spec_fn({'in': gen})['out']), normalize_torch_mel(mel), sample_rate
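
    # Note on shapes: the diffuser above operates on a "squeezed" waveform of shape
    # (1, squeeze_ratio, T // squeeze_ratio); pixel_shuffle_1d then folds the channel
    # groups back into the time axis, yielding a (1, 1, T) waveform. It is the 1D
    # analogue of torch.nn.functional.pixel_shuffle.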

    def perform_diffusion_from_codes(self, audio, sample_rate=22050):
        real_resampled = audio
        audio = audio.unsqueeze(0)

        mel = self.spec_fn({'in': audio})['out']
        codegen = self.local_modules['codegen'].to(mel.device)
        codes = codegen.get_codes(mel, project=True)
        mel_norm = normalize_torch_mel(mel)
        gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
                                              model_kwargs={'codes': codes, 'conditioning_input': torch.zeros_like(mel_norm[:, :, :390])})

        gen_mel_denorm = denormalize_torch_mel(gen_mel)
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
                                                       model_kwargs={'aligned_conditioning': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)

        return gen_wav, real_resampled, gen_mel, mel_norm, sample_rate
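
    # perform_diffusion_from_codes is a two-stage decode: the primary diffuser first
    # generates a normalized mel conditioned on the codes, then spectral_diffuser acts
    # as a diffusion vocoder, producing a 16-channel waveform stack that pixel_shuffle_1d
    # unfolds into audio.shape[-1] samples.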

    def perform_diffusion_from_codes_quant(self, audio, sample_rate=22050):
        audio = audio.unsqueeze(0)

        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_torch_mel(mel)
        # def denoising_fn(x):
        #     q9 = torch.quantile(x, q=.95, dim=-1).unsqueeze(-1)
        #     s = q9.clamp(1, 9999999999)
        #     x = x.clamp(-s, s) / s
        #     return x
        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
        # eta is a DDIM-only parameter; p_sample_loop does not accept it.
        sampler_kwargs = {'eta': .8} if self.ddim else {}
        gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm}, **sampler_kwargs)

        gen_mel_denorm = denormalize_torch_mel(gen_mel)
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        sampler = self.spectral_diffuser.ddim_sample_loop if self.ddim else self.spectral_diffuser.p_sample_loop
        gen_wav = sampler(self.spec_decoder, output_shape,
                          model_kwargs={'aligned_conditioning': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)

        real_wav = sampler(self.spec_decoder, output_shape,
                           model_kwargs={'aligned_conditioning': mel})
        real_wav = pixel_shuffle_1d(real_wav, 16)

        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate
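
    # Note that the 'real' reference waveform above is itself produced by the diffusion
    # vocoder (conditioned on the ground-truth mel) rather than taken from the source
    # audio, so the comparison isolates the mel generator from vocoder artifacts.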

    def perform_reconstruction_from_cheater_gen(self, audio, sample_rate=22050):
        audio = audio.unsqueeze(0)

        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_torch_mel(mel)
        cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)

        # 1. Generate the cheater latent using the input as a reference.
        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
        # Center-pad the conditioning input (the center isn't actually used). This is a
        # hack to give tfdpc5 a bigger working context.
        cheater_padded = torch.cat([cheater[:, :, cheater.shape[-1] // 2:], torch.zeros(1, 256, 160, device=cheater.device), cheater[:, :, :cheater.shape[-1] // 2]], dim=-1)
        gen_cheater = sampler(self.model, cheater.shape, progress=True,
                              causal=self.causal, causal_slope=self.causal_slope,
                              model_kwargs={'conditioning_input': cheater_padded, 'cond_start': 80})

        # 2. Decode the cheater into a MEL.
        gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1, 256, gen_cheater.shape[-1] * 16), progress=True,
                                                                 model_kwargs={'codes': gen_cheater.permute(0, 2, 1)})

        # 3. And then the MEL back into a waveform.
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_mel_denorm = denormalize_torch_mel(gen_mel)
        gen_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape,
                                                          model_kwargs={'codes': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)

        real_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape,
                                                           model_kwargs={'codes': mel})
        real_wav = pixel_shuffle_1d(real_wav, 16)

        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def perform_diffusion_from_codes_ar_prior(self, audio, sample_rate=22050):
        audio = audio.unsqueeze(0)

        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_torch_mel(mel)
        cheater = self.local_modules['cheater_encoder'].to(audio.device)(mel_norm)
        cheater_codes = self.kmeans_inj({'in': cheater})['out']
        ar_latent = self.local_modules['ar_prior'].to(audio.device)(cheater_codes, cheater, return_latent=True)

        # 1. Generate the cheater latent using the input as a reference.
        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
        gen_cheater = sampler(self.model, cheater.shape, progress=True,
                              causal=self.causal, causal_slope=self.causal_slope,
                              model_kwargs={'codes': ar_latent})

        # 2. Decode the cheater into a MEL.
        gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1, 256, gen_cheater.shape[-1] * 16), progress=True,
                                                                 model_kwargs={'codes': gen_cheater.permute(0, 2, 1)})
        gen_mel_denorm = denormalize_torch_mel(gen_mel)

        # 3. Decode into waveform.
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape, model_kwargs={'codes': gen_mel_denorm})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)

        real_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape, model_kwargs={'codes': mel})
        real_wav = pixel_shuffle_1d(real_wav, 16)

        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

    def perform_chained_sr(self, audio, sample_rate=22050):
        audio = audio.unsqueeze(0)
        mel = self.spec_fn({'in': audio})['out']
        mel_norm = normalize_torch_mel(mel)
        conditioning = mel_norm[:, :, :1200]
        downsampled = F.interpolate(mel_norm, scale_factor=1 / 16, mode='nearest')
        stage1_shape = (1, 256, downsampled.shape[-1] * 4)
        sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
        # Chain super-sampling using 2 stages.
        stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
                                                                 'x_prior': downsampled,
                                                                 'conditioning_input': conditioning})
        stage2 = sampler(self.model, mel.shape, model_kwargs={'resolution': torch.tensor([0], device=audio.device),
                                                              'x_prior': stage1,
                                                              'conditioning_input': conditioning})
        # Decode into waveform.
        output_shape = (1, 16, audio.shape[-1] // 16)
        self.spec_decoder = self.spec_decoder.to(audio.device)
        gen_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape, model_kwargs={'codes': stage2})
        gen_wav = pixel_shuffle_1d(gen_wav, 16)

        real_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape, model_kwargs={'codes': mel})
        real_wav = pixel_shuffle_1d(real_wav, 16)

        return gen_wav, real_wav.squeeze(0), stage2, mel_norm, sample_rate
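
    # perform_chained_sr super-resolves in two 4x hops: the mel is decimated 16x along
    # time, stage 1 upsamples it 4x (to quarter resolution), and stage 2 upsamples the
    # stage-1 output to the full mel resolution before the waveform decode.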

    def project(self, sample, sample_rate):
        sample = torchaudio.functional.resample(sample, sample_rate, 22050)
        mel = self.spec_fn({'in': sample})['out']
        projection = self.projector.project(mel)
        return projection.squeeze(0)  # Getting rid of the batch dimension means it's just [hidden_dim].

    def compute_frechet_distance(self, proj1, proj2):
        # Annoyingly, pytorch_fid operates in numpy land, so the projections have to
        # round-trip through numpy here.
        proj1 = proj1.cpu().numpy()
        proj2 = proj2.cpu().numpy()
        mu1 = np.mean(proj1, axis=0)
        mu2 = np.mean(proj2, axis=0)
        sigma1 = np.cov(proj1, rowvar=False)
        sigma2 = np.cov(proj2, rowvar=False)
        try:
            return torch.tensor(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))
        except Exception:
            # calculate_frechet_distance can fail on degenerate covariances; return a
            # zero tensor so the caller always receives a tensor.
            return torch.tensor(0.0)
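
    # For reference, the Frechet distance computed above between two Gaussians
    # N(mu1, sigma1) and N(mu2, sigma2) is:
    #   FD = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))
    # where sqrtm is the matrix square root. pytorch_fid implements this in numpy,
    # hence the conversions above.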

    def perform_eval(self):
        save_path = osp.join(self.env['base_path'], "../", "audio_eval", str(self.env["step"]))
        os.makedirs(save_path, exist_ok=True)

        self.projector = self.projector.to(self.dev)
        self.projector.eval()

        # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
        rng_state = torch.get_rng_state()
        torch.manual_seed(5)
        self.model.eval()

        with torch.no_grad():
            gen_projections = []
            real_projections = []
            for i in tqdm(list(range(0, len(self.data), self.skip))):
                path = self.data[(i + self.env['rank']) % len(self.data)]
                audio = load_audio(path, 22050).to(self.dev)
                # audio = load_audio('C:\\Users\\James\\Music\\another_longer_sample.wav', 22050).to(self.dev)  # <- hack, remove it!
                # audio = audio[:, :1764000]
                if self.clip:
                    audio = audio[:, :100000]
                sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)

                gen_projections.append(self.project(sample, sample_rate).cpu())  # Store on CPU to avoid wasting GPU memory.
                real_projections.append(self.project(ref, sample_rate).cpu())

                torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
                torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
                torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_gen_mel.png"))
                torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_real_mel.png"))
            gen_projections = torch.stack(gen_projections, dim=0)
            real_projections = torch.stack(real_projections, dim=0)
            # compute_frechet_distance already returns a tensor; just move it to the eval device.
            frechet_distance = self.compute_frechet_distance(gen_projections, real_projections).to(self.env['device'])

        if distributed.is_initialized() and distributed.get_world_size() > 1:
            distributed.all_reduce(frechet_distance)
            frechet_distance = frechet_distance / distributed.get_world_size()

        self.model.train()
        torch.set_rng_state(rng_state)

        # Put modules used for evaluation back into CPU memory.
        for k, mod in self.local_modules.items():
            self.local_modules[k] = mod.cpu()
        self.spec_decoder = self.spec_decoder.cpu()

        return {"frechet_distance": frechet_distance}


if __name__ == '__main__':
    """
    # For multilevel SR:
    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
                                       also_load_savepoint=False, strict_load=False,
                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\4000_generator.pth'
                                       ).cuda()
    opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                # 'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
                'diffusion_steps': 128,  # basis: 192
                'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
                'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr',
                }
    """

    # For TFD+cheater trainer
    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
                                       also_load_savepoint=False, strict_load=False,
                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\20000_generator.pth'
                                       ).cuda()
    opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                # 'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
                'diffusion_steps': 128,  # basis: 192
                'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
                'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
                }

    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 6, 'device': 'cuda', 'opt': {}}
    fid_eval = MusicDiffusionFid(diffusion, opt_eval, env)  # renamed from `eval` to avoid shadowing the builtin
    fds = []
    for i in range(2):
        res = fid_eval.perform_eval()
        print(res)
        fds.append(res['frechet_distance'])
    print(fds)