diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 3a3d7752..4a63a899 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -546,6 +546,43 @@ class GaussianDiffusion: yield out img = out["sample"] + def p_sample_loop_with_guidance( + self, + model, + guidance_input, + mask, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + ): + if device is None: + device = next(model.parameters()).device + shape = guidance_input.shape + if noise is None: + noise = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + img = noise + for i in tqdm(indices): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + model_driven_out = out["sample"] * mask.logical_not() + guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask + img = model_driven_out + guidance_driven_out + return img + def ddim_sample( self, model, diff --git a/codes/train.py b/codes/train.py index b4b76d6a..e1377bb4 100644 --- a/codes/train.py +++ b/codes/train.py @@ -327,7 +327,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_music_gap_filler.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index ec789a9f..da8c439e 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -2,31 +2,23 @@ import os import os.path as osp from glob import glob +import numpy as np import torch import torchaudio import torchvision from pytorch_fid.fid_score import calculate_frechet_distance from torch import distributed from tqdm import tqdm -from transformers import Wav2Vec2ForCTC -import torch.nn.functional as F -import numpy as np import trainer.eval.evaluator as evaluator -from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes from data.audio.unsupervised_audio_dataset import load_audio -from data.audio.voice_tokenizer import VoiceBpeTokenizer from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen from models.clip.contrastive_audio import ContrastiveAudio -from models.clip.mel_text_clip import MelTextCLIP -from models.audio.tts.tacotron2 import text_to_sequence from models.diffusion.gaussian_diffusion import get_named_beta_schedule from models.diffusion.respace import space_timesteps, SpacedDiffusion -from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \ - convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \ normalize_mel -from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate +from utils.util import opt_get, load_model_from_config class MusicDiffusionFid(evaluator.Evaluator): @@ -88,18 +80,20 @@ class MusicDiffusionFid(evaluator.Evaluator): model_kwargs={'aligned_conditioning': mel}) gen = pixel_shuffle_1d(gen, 16) - return gen, real_resampled, sample_rate + return gen, real_resampled, self.spec_fn({'in': gen})['out'], mel, sample_rate - def gen_freq_gap(self, mel, band_range=(130,150)): + def gen_freq_gap(self, mel, band_range=(60,100)): gap_start, gap_end = band_range - mel[:, gap_start:gap_end] = 0 - return mel + mask = torch.ones_like(mel) + mask[:, gap_start:gap_end] = 0 + return mel * mask, mask def gen_time_gap(self, mel): - mel[:, :, 22050*5:22050*6] = 0 - return mel + mask = torch.ones_like(mel) + mask[:, :, 86*4:86*6] = 0 + return mel * mask, mask - def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)): + def perform_diffusion_gap_fill(self, audio, sample_rate=22050): if sample_rate != sample_rate: real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0) else: @@ -109,15 +103,10 @@ class MusicDiffusionFid(evaluator.Evaluator): # Fetch the MEL and mask out the requested bands. mel = self.spec_fn({'in': audio})['out'] mel = normalize_mel(mel) - mel = self.gap_gen_fn(mel) - output_shape = (1, mel.shape[1], mel.shape[2]) + mel, mask = self.gap_gen_fn(mel) # Repair the MEL with the given model. - spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device), - model_kwargs={'truth': mel}) - import torchvision - torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png') - torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png') + spec = self.diffuser.p_sample_loop_with_guidance(self.model, mel, mask, model_kwargs={'truth': mel}) spec = denormalize_mel(spec) # Re-convert the resulting MEL back into audio using the spectrogram decoder. @@ -128,7 +117,7 @@ class MusicDiffusionFid(evaluator.Evaluator): model_kwargs={'aligned_conditioning': spec}) gen = pixel_shuffle_1d(gen, 16) - return gen, real_resampled, sample_rate + return gen, real_resampled, normalize_mel(spec), mel, sample_rate def project(self, sample, sample_rate): sample = torchaudio.functional.resample(sample, sample_rate, 22050) @@ -164,21 +153,23 @@ class MusicDiffusionFid(evaluator.Evaluator): for i in tqdm(list(range(0, len(self.data), self.skip))): path = self.data[i + self.env['rank']] audio = load_audio(path, 22050).to(self.dev) - audio = audio[:, :22050*5] - sample, ref, sample_rate = self.diffusion_fn(audio) + #audio = audio[:, :22050*8] + sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio) gen_projections.append(self.project(sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory. real_projections.append(self.project(ref, sample_rate).cpu()) torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate) torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate) + torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_gen_mel.png")) + torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_real_mel.png")) gen_projections = torch.stack(gen_projections, dim=0) real_projections = torch.stack(real_projections, dim=0) frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device']) if distributed.is_initialized() and distributed.get_world_size() > 1: distributed.all_reduce(frechet_distance) - frechet_distance = frechet_distance / distributed.get_world_size()\ + frechet_distance = frechet_distance / distributed.get_world_size() self.model.train() torch.set_rng_state(rng_state) @@ -193,10 +184,10 @@ class MusicDiffusionFid(evaluator.Evaluator): if __name__ == '__main__': diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator', also_load_savepoint=False, - load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\14000_generator.pth').cuda() - opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500, + load_path='X:\\dlas\\experiments\\train_music_gap_filler2\\models\\20500_generator_ema.pth').cuda() + opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50, 'conditioning_free': False, 'conditioning_free_k': 1, - 'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'} + 'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_time'} env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}} eval = MusicDiffusionFid(diffusion, opt_eval, env) print(eval.perform_eval())