music fid updates

This commit is contained in:
James Betker 2022-05-08 18:49:39 -06:00
parent 7812c23c7a
commit 1177c35dec
3 changed files with 60 additions and 32 deletions

View File

@ -546,6 +546,43 @@ class GaussianDiffusion:
yield out
img = out["sample"]
def p_sample_loop_with_guidance(
self,
model,
guidance_input,
mask,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
):
if device is None:
device = next(model.parameters()).device
shape = guidance_input.shape
if noise is None:
noise = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
img = noise
for i in tqdm(indices):
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
model_driven_out = out["sample"] * mask.logical_not()
guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask
img = model_driven_out + guidance_driven_out
return img
def ddim_sample(
self,
model,

View File

@ -327,7 +327,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_music_gap_filler.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)

View File

@ -2,31 +2,23 @@ import os
import os.path as osp
from glob import glob
import numpy as np
import torch
import torchaudio
import torchvision
from pytorch_fid.fid_score import calculate_frechet_distance
from torch import distributed
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC
import torch.nn.functional as F
import numpy as np
import trainer.eval.evaluator as evaluator
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
from data.audio.unsupervised_audio_dataset import load_audio
from data.audio.voice_tokenizer import VoiceBpeTokenizer
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
from models.clip.contrastive_audio import ContrastiveAudio
from models.clip.mel_text_clip import MelTextCLIP
from models.audio.tts.tacotron2 import text_to_sequence
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import space_timesteps, SpacedDiffusion
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
normalize_mel
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
from utils.util import opt_get, load_model_from_config
class MusicDiffusionFid(evaluator.Evaluator):
@ -88,18 +80,20 @@ class MusicDiffusionFid(evaluator.Evaluator):
model_kwargs={'aligned_conditioning': mel})
gen = pixel_shuffle_1d(gen, 16)
return gen, real_resampled, sample_rate
return gen, real_resampled, self.spec_fn({'in': gen})['out'], mel, sample_rate
def gen_freq_gap(self, mel, band_range=(130,150)):
def gen_freq_gap(self, mel, band_range=(60,100)):
gap_start, gap_end = band_range
mel[:, gap_start:gap_end] = 0
return mel
mask = torch.ones_like(mel)
mask[:, gap_start:gap_end] = 0
return mel * mask, mask
def gen_time_gap(self, mel):
mel[:, :, 22050*5:22050*6] = 0
return mel
mask = torch.ones_like(mel)
mask[:, :, 86*4:86*6] = 0
return mel * mask, mask
def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)):
def perform_diffusion_gap_fill(self, audio, sample_rate=22050):
if sample_rate != sample_rate:
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
else:
@ -109,15 +103,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
# Fetch the MEL and mask out the requested bands.
mel = self.spec_fn({'in': audio})['out']
mel = normalize_mel(mel)
mel = self.gap_gen_fn(mel)
output_shape = (1, mel.shape[1], mel.shape[2])
mel, mask = self.gap_gen_fn(mel)
# Repair the MEL with the given model.
spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
model_kwargs={'truth': mel})
import torchvision
torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png')
torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png')
spec = self.diffuser.p_sample_loop_with_guidance(self.model, mel, mask, model_kwargs={'truth': mel})
spec = denormalize_mel(spec)
# Re-convert the resulting MEL back into audio using the spectrogram decoder.
@ -128,7 +117,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
model_kwargs={'aligned_conditioning': spec})
gen = pixel_shuffle_1d(gen, 16)
return gen, real_resampled, sample_rate
return gen, real_resampled, normalize_mel(spec), mel, sample_rate
def project(self, sample, sample_rate):
sample = torchaudio.functional.resample(sample, sample_rate, 22050)
@ -164,21 +153,23 @@ class MusicDiffusionFid(evaluator.Evaluator):
for i in tqdm(list(range(0, len(self.data), self.skip))):
path = self.data[i + self.env['rank']]
audio = load_audio(path, 22050).to(self.dev)
audio = audio[:, :22050*5]
sample, ref, sample_rate = self.diffusion_fn(audio)
#audio = audio[:, :22050*8]
sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
gen_projections.append(self.project(sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
real_projections.append(self.project(ref, sample_rate).cpu())
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_gen_mel.png"))
torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_real_mel.png"))
gen_projections = torch.stack(gen_projections, dim=0)
real_projections = torch.stack(real_projections, dim=0)
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
if distributed.is_initialized() and distributed.get_world_size() > 1:
distributed.all_reduce(frechet_distance)
frechet_distance = frechet_distance / distributed.get_world_size()\
frechet_distance = frechet_distance / distributed.get_world_size()
self.model.train()
torch.set_rng_state(rng_state)
@ -193,10 +184,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\14000_generator.pth').cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
load_path='X:\\dlas\\experiments\\train_music_gap_filler2\\models\\20500_generator_ema.pth').cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
'conditioning_free': False, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_time'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval())