music fid updates
This commit is contained in:
parent
7812c23c7a
commit
1177c35dec
|
@ -546,6 +546,43 @@ class GaussianDiffusion:
|
|||
yield out
|
||||
img = out["sample"]
|
||||
|
||||
def p_sample_loop_with_guidance(
|
||||
self,
|
||||
model,
|
||||
guidance_input,
|
||||
mask,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
):
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
shape = guidance_input.shape
|
||||
if noise is None:
|
||||
noise = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
img = noise
|
||||
for i in tqdm(indices):
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
out = self.p_sample(
|
||||
model,
|
||||
img,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
model_driven_out = out["sample"] * mask.logical_not()
|
||||
guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask
|
||||
img = model_driven_out + guidance_driven_out
|
||||
return img
|
||||
|
||||
def ddim_sample(
|
||||
self,
|
||||
model,
|
||||
|
|
|
@ -327,7 +327,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_music_gap_filler.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
|
@ -2,31 +2,23 @@ import os
|
|||
import os.path as osp
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchvision
|
||||
from pytorch_fid.fid_score import calculate_frechet_distance
|
||||
from torch import distributed
|
||||
from tqdm import tqdm
|
||||
from transformers import Wav2Vec2ForCTC
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
import trainer.eval.evaluator as evaluator
|
||||
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
||||
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
|
||||
from models.clip.contrastive_audio import ContrastiveAudio
|
||||
from models.clip.mel_text_clip import MelTextCLIP
|
||||
from models.audio.tts.tacotron2 import text_to_sequence
|
||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
||||
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
||||
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
|
||||
normalize_mel
|
||||
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
||||
from utils.util import opt_get, load_model_from_config
|
||||
|
||||
|
||||
class MusicDiffusionFid(evaluator.Evaluator):
|
||||
|
@ -88,18 +80,20 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
model_kwargs={'aligned_conditioning': mel})
|
||||
gen = pixel_shuffle_1d(gen, 16)
|
||||
|
||||
return gen, real_resampled, sample_rate
|
||||
return gen, real_resampled, self.spec_fn({'in': gen})['out'], mel, sample_rate
|
||||
|
||||
def gen_freq_gap(self, mel, band_range=(130,150)):
|
||||
def gen_freq_gap(self, mel, band_range=(60,100)):
|
||||
gap_start, gap_end = band_range
|
||||
mel[:, gap_start:gap_end] = 0
|
||||
return mel
|
||||
mask = torch.ones_like(mel)
|
||||
mask[:, gap_start:gap_end] = 0
|
||||
return mel * mask, mask
|
||||
|
||||
def gen_time_gap(self, mel):
|
||||
mel[:, :, 22050*5:22050*6] = 0
|
||||
return mel
|
||||
mask = torch.ones_like(mel)
|
||||
mask[:, :, 86*4:86*6] = 0
|
||||
return mel * mask, mask
|
||||
|
||||
def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)):
|
||||
def perform_diffusion_gap_fill(self, audio, sample_rate=22050):
|
||||
if sample_rate != sample_rate:
|
||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||
else:
|
||||
|
@ -109,15 +103,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
# Fetch the MEL and mask out the requested bands.
|
||||
mel = self.spec_fn({'in': audio})['out']
|
||||
mel = normalize_mel(mel)
|
||||
mel = self.gap_gen_fn(mel)
|
||||
output_shape = (1, mel.shape[1], mel.shape[2])
|
||||
mel, mask = self.gap_gen_fn(mel)
|
||||
|
||||
# Repair the MEL with the given model.
|
||||
spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
||||
model_kwargs={'truth': mel})
|
||||
import torchvision
|
||||
torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png')
|
||||
torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png')
|
||||
spec = self.diffuser.p_sample_loop_with_guidance(self.model, mel, mask, model_kwargs={'truth': mel})
|
||||
spec = denormalize_mel(spec)
|
||||
|
||||
# Re-convert the resulting MEL back into audio using the spectrogram decoder.
|
||||
|
@ -128,7 +117,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
model_kwargs={'aligned_conditioning': spec})
|
||||
gen = pixel_shuffle_1d(gen, 16)
|
||||
|
||||
return gen, real_resampled, sample_rate
|
||||
return gen, real_resampled, normalize_mel(spec), mel, sample_rate
|
||||
|
||||
def project(self, sample, sample_rate):
|
||||
sample = torchaudio.functional.resample(sample, sample_rate, 22050)
|
||||
|
@ -164,21 +153,23 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
for i in tqdm(list(range(0, len(self.data), self.skip))):
|
||||
path = self.data[i + self.env['rank']]
|
||||
audio = load_audio(path, 22050).to(self.dev)
|
||||
audio = audio[:, :22050*5]
|
||||
sample, ref, sample_rate = self.diffusion_fn(audio)
|
||||
#audio = audio[:, :22050*8]
|
||||
sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
|
||||
|
||||
gen_projections.append(self.project(sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
|
||||
real_projections.append(self.project(ref, sample_rate).cpu())
|
||||
|
||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
|
||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
|
||||
torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_gen_mel.png"))
|
||||
torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_real_mel.png"))
|
||||
gen_projections = torch.stack(gen_projections, dim=0)
|
||||
real_projections = torch.stack(real_projections, dim=0)
|
||||
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
|
||||
|
||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||
distributed.all_reduce(frechet_distance)
|
||||
frechet_distance = frechet_distance / distributed.get_world_size()\
|
||||
frechet_distance = frechet_distance / distributed.get_world_size()
|
||||
|
||||
self.model.train()
|
||||
torch.set_rng_state(rng_state)
|
||||
|
@ -193,10 +184,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
if __name__ == '__main__':
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\14000_generator.pth').cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
|
||||
load_path='X:\\dlas\\experiments\\train_music_gap_filler2\\models\\20500_generator_ema.pth').cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
|
||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_time'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
|
||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
Loading…
Reference in New Issue
Block a user