some stuff

James Betker 2022-05-15 21:50:54 -06:00
parent ab5acead0e
commit 8202b9f39c
3 changed files with 66 additions and 27 deletions

View File

@@ -45,7 +45,7 @@ def is_wav_file(filename):
 def is_audio_file(filename):
-    AUDIO_EXTENSIONS = ['.wav', '.mp3', '.wma', 'm4b']
+    AUDIO_EXTENSIONS = ['.wav', '.mp3', '.wma', '.m4b', '.flac']
     return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)
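As a quick sanity check on the extension fix, here is a minimal sketch (the filenames are made up for illustration) of what the corrected list changes: the old 'm4b' entry lacked a leading dot and accepted any name merely ending in those letters, while '.flac' was not recognized at all.

AUDIO_EXTENSIONS = ['.wav', '.mp3', '.wma', '.m4b', '.flac']

def is_audio_file(filename):
    return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)

print(is_audio_file('book.m4b'))    # True with either the old or the new list
print(is_audio_file('track.flac'))  # True only with the new list
print(is_audio_file('chapterm4b'))  # False now; the dotless 'm4b' entry wrongly accepted it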

View File

@@ -1,10 +1,12 @@
 import argparse
+import torch
 import torchaudio
 from data.audio.unsupervised_audio_dataset import load_audio
 from scripts.audio.gen.speech_synthesis_utils import do_spectrogram_diffusion, \
-    load_discrete_vocoder_diffuser, wav_to_mel, convert_mel_to_codes
+    load_discrete_vocoder_diffuser, wav_to_mel, convert_mel_to_codes, wav_to_univnet_mel, load_univnet_vocoder
+from trainer.injectors.audio_injectors import denormalize_mel
 from utils.audio import plot_spectrogram
 from utils.util import load_model_from_config
@@ -24,28 +26,30 @@ def roundtrip_vocoding(dvae, vocoder, diffuser, clip, cond=None, plot_spec=False
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
+    parser.add_argument('-codes_file', type=str, help='Which discretes to decode. Should be a path to a pytorch pickle that simply contains the codes.')
+    parser.add_argument('-cond_file', type=str, help='Path to the input audio file.')
     parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model',
-                        default='X:\\dlas\\experiments\\train_diffusion_vocoder_22k_level.yml')
+                        default='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\last_train.yml')
     parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
-    parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_22k_level\\models\\2500_generator.pth')
-    parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
-    parser.add_argument('-input_file', type=str, help='Path to the input audio file.', default='Y:\\clips\\books1\\3_dchha04 Romancing The Tribes\\00036.wav')
-    parser.add_argument('-cond', type=str, help='Path to the conditioning input audio file.', default='Y:\\clips\\books1\\3042_18_Holden__000000000\\00037.wav')
+    parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\models\\114000_generator_ema.pth')
     args = parser.parse_args()

-    print("Loading DVAE..")
-    dvae = load_model_from_config(args.opt, args.dvae_model_name)
-    print("Loading Diffusion Model..")
-    diffusion = load_model_from_config(args.opt, args.diffusion_model_name, also_load_savepoint=False, load_path=args.diffusion_model_path)
     print("Loading data..")
-    diffuser = load_discrete_vocoder_diffuser()
-    inp = load_audio(args.input_file, 22050).cuda()
-    cond = inp if args.cond is None else load_audio(args.cond, 22050)
-    if cond.shape[-1] > 44100+10000:
-        cond = cond[:,10000:54100]
-    cond = cond.cuda()
+    codes = torch.load(args.codes_file)
+    conds = load_audio(args.cond_file, 24000)
+    conds = conds[:,:102400]
+    cond_mel = wav_to_univnet_mel(conds.to('cuda'), do_normalization=False)
+    output_shape = (1,100,codes.shape[-1]*4)
+
+    print("Loading Diffusion Model..")
+    diffusion = load_model_from_config(args.opt, args.diffusion_model_name, also_load_savepoint=False, load_path=args.diffusion_model_path, strict_load=False).cuda().eval()
+    diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=50, schedule='linear', enable_conditioning_free_guidance=True, conditioning_free_k=1)
+    vocoder = load_univnet_vocoder().cuda()

-    print("Performing inference..")
-    roundtripped = roundtrip_vocoding(dvae, diffusion, diffuser, inp, cond).cpu()
-    torchaudio.save('roundtrip_vocoded_output.wav', roundtripped.squeeze(0), 22050)
+    with torch.no_grad():
+        print("Performing inference..")
+        for i in range(codes.shape[0]):
+            gen_mel = diffuser.p_sample_loop(diffusion, output_shape, model_kwargs={'aligned_conditioning': codes[i].unsqueeze(0), 'conditioning_input': cond_mel})
+            gen_mel = denormalize_mel(gen_mel)
+            genWav = vocoder.inference(gen_mel)
+            torchaudio.save(f'vocoded_{i}.wav', genWav.cpu().squeeze(0), 24000)
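For context, a minimal sketch of preparing the tensor the new -codes_file flag expects; the script torch.load's it directly and iterates over its first dimension. The shapes and the vocabulary size below are assumptions for illustration, not values taken from this commit.

import torch

# Hypothetical code tensor of shape (num_clips, code_length); the vocabulary size 8192 is an assumption.
codes = torch.randint(0, 8192, (2, 200))
torch.save(codes, 'my_codes.pth')
# The reworked script then diffuses each row into a mel of shape (1, 100, 4 * code_length),
# denormalizes it, and runs the univnet vocoder, writing one vocoded_{i}.wav per row at 24kHz.
# The conditioning clip given via -cond_file is loaded at 24kHz and cropped to its first 102400 samples.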

View File

@@ -1,6 +1,7 @@
 import os
 import os.path as osp
 from glob import glob
+from random import shuffle
 import numpy as np
 import torch
@@ -63,6 +64,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
                 self.gap_gen_fn = self.gen_freq_gap
             else:
                 self.gap_gen_fn = self.gen_time_gap
+        elif 'rerender' in mode:
+            self.diffusion_fn = self.perform_rerender
         self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})

     def load_data(self, path):
@@ -80,7 +83,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                           model_kwargs={'aligned_conditioning': mel})
         gen = pixel_shuffle_1d(gen, 16)
-        return gen, real_resampled, self.spec_fn({'in': gen})['out'], mel, sample_rate
+        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate

     def gen_freq_gap(self, mel, band_range=(60,100)):
         gap_start, gap_end = band_range
@@ -119,6 +122,38 @@ class MusicDiffusionFid(evaluator.Evaluator):
         return gen, real_resampled, normalize_mel(spec), mel, sample_rate

+    def perform_rerender(self, audio, sample_rate=22050):
+        if sample_rate != sample_rate:
+            real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
+        else:
+            real_resampled = audio
+        audio = audio.unsqueeze(0)
+
+        # Fetch the MEL and mask out the requested bands.
+        mel = self.spec_fn({'in': audio})['out']
+        mel = normalize_mel(mel)
+        segments = [(0,10),(10,25),(25,45),(45,60),(60,80),(80,100),(100,130),(130,170),(170,210),(210,256)]
+        shuffle(segments)
+        spec = mel
+        for i, segment in enumerate(segments):
+            mel, mask = self.gen_freq_gap(mel, band_range=segment)
+            # Repair the MEL with the given model.
+            spec = self.diffuser.p_sample_loop_with_guidance(self.model, spec, mask, model_kwargs={'truth': spec})
+            torchvision.utils.save_image((spec.unsqueeze(1) + 1) / 2, f"{i}_rerender.png")
+        spec = denormalize_mel(spec)
+
+        # Re-convert the resulting MEL back into audio using the spectrogram decoder.
+        output_shape = (1, 16, audio.shape[-1] // 16)
+        self.spec_decoder = self.spec_decoder.to(audio.device)
+        # Cool fact: we can re-use the diffuser for the spectrogram diffuser since it has the same parametrization.
+        gen = self.diffuser.p_sample_loop(self.spec_decoder, output_shape, noise=torch.zeros(*output_shape, device=audio.device),
+                                          model_kwargs={'aligned_conditioning': spec})
+        gen = pixel_shuffle_1d(gen, 16)
+
+        return gen, real_resampled, normalize_mel(spec), mel, sample_rate
+
     def project(self, sample, sample_rate):
         sample = torchaudio.functional.resample(sample, sample_rate, 22050)
         mel = self.spec_fn({'in': sample})['out']
@@ -182,12 +217,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_gap_filler.yml', 'generator',
+    diffusion = load_model_from_config('D:\\dlas\\options\\train_music_waveform_gen3.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_gap_filler2\\models\\20500_generator_ema.pth').cuda()
-    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 50,
+                                       load_path='D:\\dlas\\experiments\\train_music_waveform_gen\\models\\59000_generator_ema.pth').cuda()
+    opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 400,
                 'conditioning_free': False, 'conditioning_free_k': 1,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'gap_fill_time'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 3, 'device': 'cuda', 'opt': {}}
+                'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 20, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
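Finally, a hedged sketch of driving the new rerender path from the same __main__ block, assuming 'diffusion_type' is the mode string checked in __init__ and that self.model and self.spec_decoder are wired up elsewhere in the evaluator; the paths are the placeholders already used above.

# Hypothetical opt_eval selecting the rerender branch added in this commit;
# 'rerender' in the mode string routes evaluation through perform_rerender.
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 400,
            'conditioning_free': False, 'conditioning_free_k': 1,
            'diffusion_schedule': 'linear', 'diffusion_type': 'rerender'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 20, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)  # diffusion loaded as in the block above
print(eval.perform_eval())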