new bounds for MEL normalization and multi-resolution SR in MDF

pull/2/head
James Betker 2022-07-19 11:11:46 +07:00
parent eecb534e66
commit da9e47ca0e
4 changed files with 57 additions and 19 deletions

@ -432,8 +432,8 @@ def inference_tfdpc5_with_cheater():
spectral_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
conditioning_free=True, conditioning_free_k=1)
from trainer.injectors.audio_injectors import denormalize_mel
gen_mel_denorm = denormalize_mel(gen_mel)
from trainer.injectors.audio_injectors import
gen_mel_denorm = denormalize_torch_mel(gen_mel)
output_shape = (1,16,gen_mel_denorm.shape[-1]*256//16)
gen_wav = spectral_diffuser.ddim_sample_loop(m2w, output_shape, model_kwargs={'codes': gen_mel_denorm})
from trainer.injectors.audio_injectors import pixel_shuffle_1d

@ -9,7 +9,7 @@ from tqdm import tqdm
from models.audio.music.tfdpc_v5 import TransformerDiffusionWithPointConditioning
from utils.music_utils import get_cheater_decoder, get_mel2wav_v3_model
from utils.util import load_audio
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector, denormalize_mel, pixel_shuffle_1d
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector, denormalize_torch_mel, pixel_shuffle_1d
from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
from models.diffusion.respace import SpacedDiffusion
from models.diffusion.respace import space_timesteps
@ -67,7 +67,7 @@ def join_music_with_cheaters(clip1_cheater, clip2_cheater, results_dir):
model_kwargs={'codes': chunk_cheater.permute(0, 2, 1)})
torchvision.utils.save_image((gen_mel + 1) / 2, f'{results_dir}/mel_{i}.png')
gen_mel_denorm = denormalize_mel(gen_mel)
gen_mel_denorm = denormalize_torch_mel(gen_mel)
output_shape = (1, 16, gen_mel_denorm.shape[-1] * 256 // 16)
wav = spectral_diffuser.ddim_sample_loop(m2w, output_shape, progress=True,
model_kwargs={'codes': gen_mel_denorm})

@ -12,6 +12,7 @@ import torchvision
from pytorch_fid.fid_score import calculate_frechet_distance
from torch import distributed
from tqdm import tqdm
import torch.nn.functional as F
import trainer.eval.evaluator as evaluator
from data.audio.unsupervised_audio_dataset import load_audio
@ -20,7 +21,7 @@ from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
from models.clip.contrastive_audio import ContrastiveAudio
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
from trainer.injectors.audio_injectors import denormalize_torch_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
normalize_mel, KmeansQuantizerInjector
from utils.music_utils import get_music_codegen, get_mel2wav_model, get_cheater_decoder, get_cheater_encoder, \
get_mel2wav_v3_model, get_ar_prior
@ -92,6 +93,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
self.local_modules['ar_prior'] = get_ar_prior()
self.spec_decoder = get_mel2wav_v3_model()
self.local_modules['spec_decoder'] = self.spec_decoder
elif 'chained_sr' == mode:
self.diffusion_fn = self.perform_chained_sr
self.spec_decoder = get_mel2wav_v3_model()
self.local_modules['spec_decoder'] = self.spec_decoder
if not hasattr(self, 'spec_decoder'):
self.spec_decoder = get_mel2wav_model()
self.local_modules['spec_decoder'] = self.spec_decoder
@ -123,7 +128,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
gen_mel = self.diffuser.p_sample_loop(self.model, mel_norm.shape,
model_kwargs={'codes': codes, 'conditioning_input': torch.zeros_like(mel_norm[:,:,:390])})
gen_mel_denorm = denormalize_mel(gen_mel)
gen_mel_denorm = denormalize_torch_mel(gen_mel)
output_shape = (1,16,audio.shape[-1]//16)
self.spec_decoder = self.spec_decoder.to(audio.device)
gen_wav = self.spectral_diffuser.p_sample_loop(self.spec_decoder, output_shape,
@ -154,7 +159,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
# 3. And then the MEL back into a spectrogram
output_shape = (1,16,audio.shape[-1]//16)
self.spec_decoder = self.spec_decoder.to(audio.device)
gen_mel_denorm = denormalize_mel(gen_mel)
gen_mel_denorm = denormalize_torch_mel(gen_mel)
gen_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape,
model_kwargs={'codes': gen_mel_denorm})
gen_wav = pixel_shuffle_1d(gen_wav, 16)
@ -183,7 +188,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
# 2. Decode the cheater into a MEL
gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,gen_cheater.shape[-1]*16), progress=True,
model_kwargs={'codes': gen_cheater.permute(0,2,1)})
gen_mel_denorm = denormalize_mel(gen_mel)
gen_mel_denorm = denormalize_torch_mel(gen_mel)
# 3. Decode into waveform.
output_shape = (1,16,audio.shape[-1]//16)
@ -196,6 +201,32 @@ class MusicDiffusionFid(evaluator.Evaluator):
return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate
def perform_chained_sr(self, audio, sample_rate=22050):
audio = audio.unsqueeze(0)
mel = self.spec_fn({'in': audio})['out']
mel_norm = normalize_mel(mel)
conditioning = mel_norm[:,:,:1200]
downsampled = F.interpolate(mel_norm, scale_factor=1/16, mode='linear', align_corners=True)
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
stage1_shape = (1, 256, downsampled.shape[-1]*4)
# Chain super-sampling using 2 stages.
stage1 = sampler(self.model, stage1_shape, model_kwargs={'resolution': torch.tensor([2], device=audio.device),
'x_prior': downsampled,
'conditioning_input': conditioning})
stage2 = sampler(self.model, audio.shape, model_kwargs={'resolution': torch.tensor([1], device=audio.device),
'x_prior': stage1,
'conditioning_input': conditioning})
# Decode into waveform.
output_shape = (1,16,audio.shape[-1]//16)
self.spec_decoder = self.spec_decoder.to(audio.device)
gen_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape, model_kwargs={'codes': stage2})
gen_wav = pixel_shuffle_1d(gen_wav, 16)
real_wav = self.spectral_diffuser.ddim_sample_loop(self.spec_decoder, output_shape, model_kwargs={'codes': mel})
real_wav = pixel_shuffle_1d(real_wav, 16)
return gen_wav, real_wav.squeeze(0), stage2, mel_norm, sample_rate
def project(self, sample, sample_rate):
sample = torchaudio.functional.resample(sample, sample_rate, 22050)
mel = self.spec_fn({'in': sample})['out']
@ -266,19 +297,19 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater\\models\\93500_generator_ema.pth'
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
also_load_savepoint=False, strict_load=False,
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator_fixed.pth'
).cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
'diffusion_steps': 256, # basis: 192
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant',
'diffusion_steps': 64, # basis: 192
'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': False,
'diffusion_schedule': 'linear', 'diffusion_type': 'chained_sr',
#'causal': True, 'causal_slope': 4,
#'partial_low': 128, 'partial_high': 192
}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 7, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 10, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
fds = []
for i in range(2):

@ -9,14 +9,21 @@ from trainer.inject import Injector
from utils.music_utils import get_music_codegen
from utils.util import opt_get, load_model_from_config, pad_or_truncate
MEL_MIN = -11.512925148010254
TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254
TORCH_MEL_MAX = 4.82
def normalize_torch_mel(mel):
return 2 * ((mel - MEL_MIN) / (TORCH_MEL_MAX - MEL_MIN)) - 1
def denormalize_torch_mel(norm_mel):
return ((norm_mel+1)/2) * (TORCH_MEL_MAX - MEL_MIN) + MEL_MIN
def normalize_mel(mel):
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
return 2 * ((mel - MEL_MIN) / (TACOTRON_MEL_MAX - MEL_MIN)) - 1
def denormalize_mel(norm_mel):
return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
return ((norm_mel+1)/2) * (TACOTRON_MEL_MAX - MEL_MIN) + MEL_MIN
class MelSpectrogramInjector(Injector):
def __init__(self, opt, env):
@ -83,7 +90,7 @@ class TorchMelSpectrogramInjector(Injector):
self.mel_norms = self.mel_norms.to(mel.device)
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
if self.true_norm:
mel = normalize_mel(mel)
mel = normalize_torch_mel(mel)
return {self.output: mel}