music fid updates

James Betker 2022-05-08 18:49:39 -06:00
parent 7812c23c7a
commit 1177c35dec
3 changed files with 60 additions and 32 deletions


@@ -546,6 +546,43 @@ class GaussianDiffusion:
                 yield out
                 img = out["sample"]
 
+    def p_sample_loop_with_guidance(
+        self,
+        model,
+        guidance_input,
+        mask,
+        noise=None,
+        clip_denoised=True,
+        denoised_fn=None,
+        cond_fn=None,
+        model_kwargs=None,
+        device=None,
+    ):
+        if device is None:
+            device = next(model.parameters()).device
+        shape = guidance_input.shape
+        if noise is None:
+            noise = th.randn(*shape, device=device)
+        indices = list(range(self.num_timesteps))[::-1]
+
+        img = noise
+        for i in tqdm(indices):
+            t = th.tensor([i] * shape[0], device=device)
+            with th.no_grad():
+                out = self.p_sample(
+                    model,
+                    img,
+                    t,
+                    clip_denoised=clip_denoised,
+                    denoised_fn=denoised_fn,
+                    cond_fn=cond_fn,
+                    model_kwargs=model_kwargs,
+                )
+                model_driven_out = out["sample"] * mask.logical_not()
+                guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask
+                img = model_driven_out + guidance_driven_out
+
+        return img
+
     def ddim_sample(
         self,
         model,
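
The new p_sample_loop_with_guidance above is an inpainting-style sampler: at every reverse step the model's prediction is kept only where the mask is zero (the gap to be filled), while the rest of the tensor is overwritten with a freshly re-noised copy of the guidance input via q_sample. A minimal, self-contained sketch of that blend; toy_q_sample and toy_denoise are hypothetical stand-ins for the class's q_sample and the trained model's p_sample step, not code from this repo:

    import torch

    def toy_q_sample(x0, t, num_steps, noise):
        # Crude forward process: interpolate toward noise as t grows. The real
        # q_sample uses the beta schedule's cumulative alphas instead.
        frac = t / num_steps
        return (1 - frac) * x0 + frac * noise

    def toy_denoise(x, t, num_steps):
        # Stand-in for one reverse step of a trained diffusion model.
        return x * 0.99

    def guided_loop(guidance_input, mask, num_steps=50):
        noise = torch.randn_like(guidance_input)
        img = noise
        for t in reversed(range(num_steps)):
            model_out = toy_denoise(img, t, num_steps)
            # Keep the model's output where mask == 0 (the gap), and a re-noised
            # copy of the known signal where mask == 1.
            img = model_out * mask.logical_not() + toy_q_sample(guidance_input, t, num_steps, noise) * mask
        return img

    if __name__ == '__main__':
        mel = torch.randn(1, 100, 344)   # (batch, mel bins, frames)
        mask = torch.ones_like(mel)
        mask[:, 60:100] = 0              # simulate a missing frequency band
        filled = guided_loop(mel * mask, mask)
        print(filled.shape)

The key design point is that the known region is never taken from the model at all; it is resampled from the ground truth at the current noise level, which keeps the filled-in gap consistent with its surroundings.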


@@ -327,7 +327,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_music_gap_filler.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)


@@ -2,31 +2,23 @@ import os
 import os.path as osp
 from glob import glob
 
+import numpy as np
 import torch
 import torchaudio
 import torchvision
 from pytorch_fid.fid_score import calculate_frechet_distance
 from torch import distributed
 from tqdm import tqdm
-from transformers import Wav2Vec2ForCTC
-import torch.nn.functional as F
-import numpy as np
 
 import trainer.eval.evaluator as evaluator
-from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
 from data.audio.unsupervised_audio_dataset import load_audio
-from data.audio.voice_tokenizer import VoiceBpeTokenizer
 from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
 from models.clip.contrastive_audio import ContrastiveAudio
-from models.clip.mel_text_clip import MelTextCLIP
-from models.audio.tts.tacotron2 import text_to_sequence
 from models.diffusion.gaussian_diffusion import get_named_beta_schedule
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
-from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
-    convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
 from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
     normalize_mel
-from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
+from utils.util import opt_get, load_model_from_config
 
 
 class MusicDiffusionFid(evaluator.Evaluator):
@@ -88,18 +80,20 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                               model_kwargs={'aligned_conditioning': mel})
         gen = pixel_shuffle_1d(gen, 16)
-        return gen, real_resampled, sample_rate
+        return gen, real_resampled, self.spec_fn({'in': gen})['out'], mel, sample_rate
 
-    def gen_freq_gap(self, mel, band_range=(130,150)):
+    def gen_freq_gap(self, mel, band_range=(60,100)):
         gap_start, gap_end = band_range
-        mel[:, gap_start:gap_end] = 0
-        return mel
+        mask = torch.ones_like(mel)
+        mask[:, gap_start:gap_end] = 0
+        return mel * mask, mask
 
     def gen_time_gap(self, mel):
-        mel[:, :, 22050*5:22050*6] = 0
-        return mel
+        mask = torch.ones_like(mel)
+        mask[:, :, 86*4:86*6] = 0
+        return mel * mask, mask
 
-    def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)):
+    def perform_diffusion_gap_fill(self, audio, sample_rate=22050):
         if sample_rate != sample_rate:
             real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
         else:
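
A note on the new gap generators: both now return the masked mel plus the mask itself, so the guided sampler knows which region to fill. The 86*4:86*6 slice in gen_time_gap corresponds to roughly seconds 4 through 6, assuming a 22050 Hz signal and a mel hop length of 256 samples (about 86 frames per second; the hop length is an assumption, it is configured in TorchMelSpectrogramInjector elsewhere). A toy illustration of both masks:

    import torch

    sample_rate, hop_length = 22050, 256            # hop length is an assumed default, not from this diff
    frames_per_sec = sample_rate // hop_length      # ~86 mel frames per second
    mel = torch.randn(1, 100, frames_per_sec * 10)  # 10 seconds of a 100-bin mel spectrogram

    # Frequency gap: blank mel bins 60..99, mirroring gen_freq_gap above.
    freq_mask = torch.ones_like(mel)
    freq_mask[:, 60:100] = 0

    # Time gap: blank roughly seconds 4..6, mirroring gen_time_gap's 86*4:86*6 slice.
    time_mask = torch.ones_like(mel)
    time_mask[:, :, frames_per_sec * 4:frames_per_sec * 6] = 0

    masked_mel = mel * time_mask                    # what gets handed to the repair model
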
@@ -109,15 +103,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
         # Fetch the MEL and mask out the requested bands.
         mel = self.spec_fn({'in': audio})['out']
         mel = normalize_mel(mel)
-        mel = self.gap_gen_fn(mel)
-        output_shape = (1, mel.shape[1], mel.shape[2])
+        mel, mask = self.gap_gen_fn(mel)
 
         # Repair the MEL with the given model.
-        spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
-                                           model_kwargs={'truth': mel})
-        import torchvision
-        torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png')
-        torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png')
+        spec = self.diffuser.p_sample_loop_with_guidance(self.model, mel, mask, model_kwargs={'truth': mel})
         spec = denormalize_mel(spec)
 
         # Re-convert the resulting MEL back into audio using the spectrogram decoder.
@@ -128,7 +117,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                               model_kwargs={'aligned_conditioning': spec})
         gen = pixel_shuffle_1d(gen, 16)
-        return gen, real_resampled, sample_rate
+        return gen, real_resampled, normalize_mel(spec), mel, sample_rate
 
     def project(self, sample, sample_rate):
         sample = torchaudio.functional.resample(sample, sample_rate, 22050)
@@ -164,21 +153,23 @@ class MusicDiffusionFid(evaluator.Evaluator):
             for i in tqdm(list(range(0, len(self.data), self.skip))):
                 path = self.data[i + self.env['rank']]
                 audio = load_audio(path, 22050).to(self.dev)
-                audio = audio[:, :22050*5]
-                sample, ref, sample_rate = self.diffusion_fn(audio)
+                #audio = audio[:, :22050*8]
+                sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
                 gen_projections.append(self.project(sample, sample_rate).cpu())  # Store on CPU to avoid wasting GPU memory.
                 real_projections.append(self.project(ref, sample_rate).cpu())
 
                 torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
                 torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
+                torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_gen_mel.png"))
+                torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_real_mel.png"))
+
             gen_projections = torch.stack(gen_projections, dim=0)
             real_projections = torch.stack(real_projections, dim=0)
             frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
 
             if distributed.is_initialized() and distributed.get_world_size() > 1:
                 distributed.all_reduce(frechet_distance)
-                frechet_distance = frechet_distance / distributed.get_world_size()\
+                frechet_distance = frechet_distance / distributed.get_world_size()
 
         self.model.train()
         torch.set_rng_state(rng_state)
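
compute_frechet_distance itself is not touched by this commit, but given the calculate_frechet_distance import from pytorch_fid at the top of the file, the projections are presumably reduced to a per-set mean and covariance and compared in embedding space. A hedged sketch of that computation; the real method may differ, for instance in how it reshapes the stacked projections, which may need a squeeze() first:

    import numpy as np
    from pytorch_fid.fid_score import calculate_frechet_distance

    def frechet_from_projections(gen_projections, real_projections):
        # Expects (N, D) CPU tensors of per-sample embeddings, one row per clip.
        gen, real = gen_projections.numpy(), real_projections.numpy()
        mu_g, sigma_g = gen.mean(axis=0), np.cov(gen, rowvar=False)
        mu_r, sigma_r = real.mean(axis=0), np.cov(real, rowvar=False)
        return calculate_frechet_distance(mu_g, sigma_g, mu_r, sigma_r)
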
@@ -193,10 +184,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\14000_generator.pth').cuda()
-    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
+                                       load_path='X:\\dlas\\experiments\\train_music_gap_filler2\\models\\20500_generator_ema.pth').cuda()
+    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
                 'conditioning_free': False, 'conditioning_free_k': 1,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
+                'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_time'}
     env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
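
Dropping 'diffusion_steps' from 500 to 50 presumably leans on the respacing utilities this file already imports: the full training schedule is subsampled down to 50 evaluation steps. A minimal sketch, assuming a 4000-step linear training schedule (the actual length is configured elsewhere); the full SpacedDiffusion construction also takes mean/variance/loss-type arguments that are omitted here:

    from models.diffusion.gaussian_diffusion import get_named_beta_schedule
    from models.diffusion.respace import space_timesteps

    betas = get_named_beta_schedule('linear', 4000)   # full-length training schedule (4000 is assumed)
    eval_steps = space_timesteps(4000, [50])          # the 50 timesteps actually visited at eval time
    print(len(betas), len(eval_steps))                # -> 4000 50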