music fid updates
This commit is contained in:
parent
7812c23c7a
commit
1177c35dec
|
@ -546,6 +546,43 @@ class GaussianDiffusion:
|
||||||
yield out
|
yield out
|
||||||
img = out["sample"]
|
img = out["sample"]
|
||||||
|
|
||||||
|
def p_sample_loop_with_guidance(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
guidance_input,
|
||||||
|
mask,
|
||||||
|
noise=None,
|
||||||
|
clip_denoised=True,
|
||||||
|
denoised_fn=None,
|
||||||
|
cond_fn=None,
|
||||||
|
model_kwargs=None,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
if device is None:
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
shape = guidance_input.shape
|
||||||
|
if noise is None:
|
||||||
|
noise = th.randn(*shape, device=device)
|
||||||
|
indices = list(range(self.num_timesteps))[::-1]
|
||||||
|
|
||||||
|
img = noise
|
||||||
|
for i in tqdm(indices):
|
||||||
|
t = th.tensor([i] * shape[0], device=device)
|
||||||
|
with th.no_grad():
|
||||||
|
out = self.p_sample(
|
||||||
|
model,
|
||||||
|
img,
|
||||||
|
t,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
denoised_fn=denoised_fn,
|
||||||
|
cond_fn=cond_fn,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
model_driven_out = out["sample"] * mask.logical_not()
|
||||||
|
guidance_driven_out = self.q_sample(guidance_input, t, noise=noise) * mask
|
||||||
|
img = model_driven_out + guidance_driven_out
|
||||||
|
return img
|
||||||
|
|
||||||
def ddim_sample(
|
def ddim_sample(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
|
|
|
@ -327,7 +327,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_music_gap_filler.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_contrastive_audio.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = option.parse(args.opt, is_train=True)
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
|
@ -2,31 +2,23 @@ import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import torchvision
|
import torchvision
|
||||||
from pytorch_fid.fid_score import calculate_frechet_distance
|
from pytorch_fid.fid_score import calculate_frechet_distance
|
||||||
from torch import distributed
|
from torch import distributed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import Wav2Vec2ForCTC
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import trainer.eval.evaluator as evaluator
|
import trainer.eval.evaluator as evaluator
|
||||||
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
|
||||||
from data.audio.unsupervised_audio_dataset import load_audio
|
from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
|
||||||
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
|
from models.audio.music.unet_diffusion_waveform_gen import DiffusionWaveformGen
|
||||||
from models.clip.contrastive_audio import ContrastiveAudio
|
from models.clip.contrastive_audio import ContrastiveAudio
|
||||||
from models.clip.mel_text_clip import MelTextCLIP
|
|
||||||
from models.audio.tts.tacotron2 import text_to_sequence
|
|
||||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||||
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
from models.diffusion.respace import space_timesteps, SpacedDiffusion
|
||||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
|
||||||
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
|
||||||
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
|
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector, pixel_shuffle_1d, \
|
||||||
normalize_mel
|
normalize_mel
|
||||||
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
|
from utils.util import opt_get, load_model_from_config
|
||||||
|
|
||||||
|
|
||||||
class MusicDiffusionFid(evaluator.Evaluator):
|
class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
|
@ -88,18 +80,20 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
model_kwargs={'aligned_conditioning': mel})
|
model_kwargs={'aligned_conditioning': mel})
|
||||||
gen = pixel_shuffle_1d(gen, 16)
|
gen = pixel_shuffle_1d(gen, 16)
|
||||||
|
|
||||||
return gen, real_resampled, sample_rate
|
return gen, real_resampled, self.spec_fn({'in': gen})['out'], mel, sample_rate
|
||||||
|
|
||||||
def gen_freq_gap(self, mel, band_range=(130,150)):
|
def gen_freq_gap(self, mel, band_range=(60,100)):
|
||||||
gap_start, gap_end = band_range
|
gap_start, gap_end = band_range
|
||||||
mel[:, gap_start:gap_end] = 0
|
mask = torch.ones_like(mel)
|
||||||
return mel
|
mask[:, gap_start:gap_end] = 0
|
||||||
|
return mel * mask, mask
|
||||||
|
|
||||||
def gen_time_gap(self, mel):
|
def gen_time_gap(self, mel):
|
||||||
mel[:, :, 22050*5:22050*6] = 0
|
mask = torch.ones_like(mel)
|
||||||
return mel
|
mask[:, :, 86*4:86*6] = 0
|
||||||
|
return mel * mask, mask
|
||||||
|
|
||||||
def perform_diffusion_gap_fill(self, audio, sample_rate=22050, band_range=(130,150)):
|
def perform_diffusion_gap_fill(self, audio, sample_rate=22050):
|
||||||
if sample_rate != sample_rate:
|
if sample_rate != sample_rate:
|
||||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||||
else:
|
else:
|
||||||
|
@ -109,15 +103,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
# Fetch the MEL and mask out the requested bands.
|
# Fetch the MEL and mask out the requested bands.
|
||||||
mel = self.spec_fn({'in': audio})['out']
|
mel = self.spec_fn({'in': audio})['out']
|
||||||
mel = normalize_mel(mel)
|
mel = normalize_mel(mel)
|
||||||
mel = self.gap_gen_fn(mel)
|
mel, mask = self.gap_gen_fn(mel)
|
||||||
output_shape = (1, mel.shape[1], mel.shape[2])
|
|
||||||
|
|
||||||
# Repair the MEL with the given model.
|
# Repair the MEL with the given model.
|
||||||
spec = self.diffuser.p_sample_loop(self.model, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
|
spec = self.diffuser.p_sample_loop_with_guidance(self.model, mel, mask, model_kwargs={'truth': mel})
|
||||||
model_kwargs={'truth': mel})
|
|
||||||
import torchvision
|
|
||||||
torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, 'gen.png')
|
|
||||||
torchvision.utils.save_image((mel.unsqueeze(1) + 1) / 2, 'mel.png')
|
|
||||||
spec = denormalize_mel(spec)
|
spec = denormalize_mel(spec)
|
||||||
|
|
||||||
# Re-convert the resulting MEL back into audio using the spectrogram decoder.
|
# Re-convert the resulting MEL back into audio using the spectrogram decoder.
|
||||||
|
@ -128,7 +117,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
model_kwargs={'aligned_conditioning': spec})
|
model_kwargs={'aligned_conditioning': spec})
|
||||||
gen = pixel_shuffle_1d(gen, 16)
|
gen = pixel_shuffle_1d(gen, 16)
|
||||||
|
|
||||||
return gen, real_resampled, sample_rate
|
return gen, real_resampled, normalize_mel(spec), mel, sample_rate
|
||||||
|
|
||||||
def project(self, sample, sample_rate):
|
def project(self, sample, sample_rate):
|
||||||
sample = torchaudio.functional.resample(sample, sample_rate, 22050)
|
sample = torchaudio.functional.resample(sample, sample_rate, 22050)
|
||||||
|
@ -164,21 +153,23 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
for i in tqdm(list(range(0, len(self.data), self.skip))):
|
for i in tqdm(list(range(0, len(self.data), self.skip))):
|
||||||
path = self.data[i + self.env['rank']]
|
path = self.data[i + self.env['rank']]
|
||||||
audio = load_audio(path, 22050).to(self.dev)
|
audio = load_audio(path, 22050).to(self.dev)
|
||||||
audio = audio[:, :22050*5]
|
#audio = audio[:, :22050*8]
|
||||||
sample, ref, sample_rate = self.diffusion_fn(audio)
|
sample, ref, sample_mel, ref_mel, sample_rate = self.diffusion_fn(audio)
|
||||||
|
|
||||||
gen_projections.append(self.project(sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
|
gen_projections.append(self.project(sample, sample_rate).cpu()) # Store on CPU to avoid wasting GPU memory.
|
||||||
real_projections.append(self.project(ref, sample_rate).cpu())
|
real_projections.append(self.project(ref, sample_rate).cpu())
|
||||||
|
|
||||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
|
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_gen.wav"), sample.squeeze(0).cpu(), sample_rate)
|
||||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
|
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.cpu(), sample_rate)
|
||||||
|
torchvision.utils.save_image((sample_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_gen_mel.png"))
|
||||||
|
torchvision.utils.save_image((ref_mel.unsqueeze(1) + 1) / 2, os.path.join(save_path, f"{self.env['rank']}_{i}_real_mel.png"))
|
||||||
gen_projections = torch.stack(gen_projections, dim=0)
|
gen_projections = torch.stack(gen_projections, dim=0)
|
||||||
real_projections = torch.stack(real_projections, dim=0)
|
real_projections = torch.stack(real_projections, dim=0)
|
||||||
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
|
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
|
||||||
|
|
||||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
distributed.all_reduce(frechet_distance)
|
distributed.all_reduce(frechet_distance)
|
||||||
frechet_distance = frechet_distance / distributed.get_world_size()\
|
frechet_distance = frechet_distance / distributed.get_world_size()
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
torch.set_rng_state(rng_state)
|
torch.set_rng_state(rng_state)
|
||||||
|
@ -193,10 +184,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
|
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
|
||||||
also_load_savepoint=False,
|
also_load_savepoint=False,
|
||||||
load_path='X:\\dlas\\experiments\\train_music_gap_filler\\models\\14000_generator.pth').cuda()
|
load_path='X:\\dlas\\experiments\\train_music_gap_filler2\\models\\20500_generator_ema.pth').cuda()
|
||||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 500,
|
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
|
||||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||||
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_freq'}
|
'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_time'}
|
||||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
|
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 2, 'device': 'cuda', 'opt': {}}
|
||||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||||
print(eval.perform_eval())
|
print(eval.perform_eval())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user