updates to audio_diffusion_fid

parent 9c6f776980
commit 1e3a8554a1
scripts/audio/gen/speech_synthesis_utils.py

@@ -5,16 +5,28 @@ import torch
 from data.audio.unsupervised_audio_dataset import load_audio
 from data.util import find_files_of_type, is_audio_file
+from models.audio.vocoders.univnet.generator import UnivNetGenerator
 from models.diffusion.gaussian_diffusion import get_named_beta_schedule
 from models.diffusion.respace import SpacedDiffusion, space_timesteps
-from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
+from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector, MelSpectrogramInjector
 from utils.audio import plot_spectrogram
 from utils.util import load_model_from_config


 def load_speech_dvae():
-    return load_model_from_config("../experiments/train_diffusion_vocoder_22k_level.yml",
+    dvae = load_model_from_config("../experiments/train_diffusion_vocoder_22k_level.yml",
                                   "dvae").cpu()
+    dvae.eval()
+    return dvae
+
+
+def load_univnet_vocoder():
+    model = UnivNetGenerator()
+    sd = torch.load('univnet_c32_pretrained_libri.pt')
+    model.load_state_dict(sd)
+    model = model.cpu()
+    model.eval()
+    return model


 def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
@@ -24,6 +36,14 @@ def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
     return TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': mel_norms_file},{})({'wav': wav})['mel']


+def wav_to_univnet_mel(wav):
+    """
+    Converts an audio clip into a MEL tensor that the univnet vocoder knows how to decode.
+    """
+    return MelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'sampling_rate': 24000,
+                                   'n_mel_channels': 100, 'mel_fmax': 12000},{})({'wav': wav})['mel']
+
+
 def convert_mel_to_codes(dvae_model, mel):
     """
     Converts an audio clip into discrete codes.
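
Note: the two new helpers above are designed to compose: wav_to_univnet_mel produces the
24kHz, 100-channel MEL that the UnivNet generator decodes back into a waveform (the diff
below calls the vocoder directly on a generated MEL). A minimal round-trip sketch, assuming
a local clip 'example.wav' (a hypothetical stand-in) and the pretrained weights above:

    import torch
    import torchaudio

    wav, sr = torchaudio.load('example.wav')              # hypothetical input clip
    wav = torchaudio.functional.resample(wav, sr, 24000)  # the helper assumes 24kHz audio

    vocoder = load_univnet_vocoder()        # loads univnet_c32_pretrained_libri.pt
    mel = wav_to_univnet_mel(wav)           # (1, 100, T) MEL per the injector config
    with torch.no_grad():
        reconstruction = vocoder(mel)       # waveform approximating the input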
audio_diffusion_fid.py

@@ -15,7 +15,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
 from models.clip.mel_text_clip import MelTextCLIP
 from models.audio.tts.tacotron2 import text_to_sequence
 from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
-    convert_mel_to_codes
+    convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
 from utils.util import ceil_multiple, opt_get


@@ -43,16 +43,19 @@ class AudioDiffusionFid(evaluator.Evaluator):
                                                conditioning_free_k=conditioning_free_k)
         self.dev = self.env['device']
         mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
+        self.local_modules = {}
         if mode == 'tts':
             self.diffusion_fn = self.perform_diffusion_tts
         elif mode == 'original_vocoder':
-            self.dvae = load_speech_dvae().to(self.env['device'])
-            self.dvae.eval()
+            self.local_modules['dvae'] = load_speech_dvae().cpu()
             self.diffusion_fn = self.perform_original_diffusion_vocoder
         elif mode == 'vocoder':
-            self.dvae = load_speech_dvae().to(self.env['device'])
-            self.dvae.eval()
+            self.local_modules['dvae'] = load_speech_dvae().cpu()
             self.diffusion_fn = self.perform_diffusion_vocoder
+        elif mode == 'tts9_mel':
+            self.local_modules['dvae'] = load_speech_dvae().cpu()
+            self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
+            self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes

     def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
         real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
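
Note: self.local_modules replaces the ad-hoc self.dvae attribute. Every eval-only helper
(dvae, vocoder) now lives in one registry, parked on CPU at construction; perform_eval
promotes the whole registry to the eval device and demotes it afterwards (see the hunks
further down). A minimal sketch of the pattern; the try/finally is an embellishment, the
actual diff restores the modules unconditionally at the end of perform_eval:

    local_modules = {'dvae': load_speech_dvae().cpu()}   # parked on CPU between evals

    def run_eval(device):
        for k, mod in local_modules.items():             # promote for the eval pass
            local_modules[k] = mod.to(device)
        try:
            pass                                         # sampling / metric computation
        finally:
            for k, mod in local_modules.items():         # release accelerator memory
                local_modules[k] = mod.cpu()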
@@ -71,8 +74,8 @@ class AudioDiffusionFid(evaluator.Evaluator):

     def perform_original_diffusion_vocoder(self, audio, codes, text, sample_rate=11025):
         mel = wav_to_mel(audio)
-        mel_codes = convert_mel_to_codes(self.dvae, mel)
-        back_to_mel = self.dvae.decode(mel_codes)[0]
+        mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
+        back_to_mel = self.local_modules['dvae'].decode(mel_codes)[0]
         orig_audio = audio
         real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)

@@ -96,7 +99,7 @@ class AudioDiffusionFid(evaluator.Evaluator):

     def perform_diffusion_vocoder(self, audio, codes, text, sample_rate=5500):
         mel = wav_to_mel(audio)
-        mel_codes = convert_mel_to_codes(self.dvae, mel)
+        mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
         text_codes = text_to_sequence(text)
         real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)

@@ -115,11 +118,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
         return gen, real_resampled, sample_rate

-    def perform_diffusion_tts9_from_codes(self, audio, codes, text, sample_rate=5500):
+    def perform_diffusion_tts9_mel_from_codes(self, audio, codes, text):
+        SAMPLE_RATE = 24000
         mel = wav_to_mel(audio)
-        mel_codes = convert_mel_to_codes(self.dvae, mel)
-        text_codes = text_to_sequence(text)
-        real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
+        mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
+        real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
+        univnet_mel = wav_to_univnet_mel(audio)  # to be used for a conditioning input

         output_size = real_resampled.shape[-1]
         aligned_codes_compression_factor = output_size // mel_codes.shape[-1]

@@ -129,11 +133,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
         if padding_needed_for_codes > 0:
             mel_codes = F.pad(mel_codes, (0, padding_needed_for_codes))
         output_shape = (1, 1, padded_size)
-        gen = self.diffuser.p_sample_loop(self.model, output_shape,
-                                          model_kwargs={'tokens': mel_codes,
-                                                        'conditioning_input': audio.unsqueeze(0),
-                                                        'unaligned_input': torch.tensor(text_codes, device=audio.device).unsqueeze(0)})
-        return gen, real_resampled, sample_rate
+        gen_mel = self.diffuser.p_sample_loop(self.model, output_shape,
+                                              model_kwargs={'aligned_conditioning': mel_codes,
+                                                            'conditioning_input': univnet_mel})
+
+        gen_wav = self.local_modules['vocoder'](gen_mel)
+        return gen_wav, real_resampled, SAMPLE_RATE

     def load_projector(self):
         """
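
Note: tts9_mel is a two-stage pipeline: the diffuser denoises a MEL conditioned on the
dvae codes ('aligned_conditioning') and the univnet-style MEL of the reference clip
('conditioning_input'), then the UnivNet vocoder renders that MEL as a 24kHz waveform.
A condensed restatement of the flow, with the tensor shapes as assumptions:

    mel = wav_to_mel(audio)                      # TorToiSe-style MEL of the reference
    mel_codes = convert_mel_to_codes(dvae, mel)  # (1, N) discrete dvae codes
    univnet_mel = wav_to_univnet_mel(audio)      # (1, 100, T) conditioning MEL
    gen_mel = diffuser.p_sample_loop(model, output_shape,
                                     model_kwargs={'aligned_conditioning': mel_codes,
                                                   'conditioning_input': univnet_mel})
    gen_wav = vocoder(gen_mel)                   # 24kHz waveform fed to the FID projector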
@@ -187,11 +192,11 @@ class AudioDiffusionFid(evaluator.Evaluator):

         projector = self.load_projector().to(self.env['device'])
         projector.eval()
-        if hasattr(self, 'dvae'):
-            self.dvae = self.dvae.to(self.env['device'])

         w2v = self.load_w2v().to(self.env['device'])
         w2v.eval()
+        for k, mod in self.local_modules.items():
+            self.local_modules[k] = mod.to(self.env['device'])

         # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
         rng_state = torch.get_rng_state()

@@ -226,10 +231,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
             intelligibility_loss = intelligibility_loss / distributed.get_world_size()

         self.model.train()
-        if hasattr(self, 'dvae'):
-            self.dvae = self.dvae.to('cpu')
         torch.set_rng_state(rng_state)
+
+        # Put modules used for evaluation back into CPU memory.
+        for k, mod in self.local_modules.items():
+            self.local_modules[k] = mod.cpu()

         return {"frechet_distance": frechet_distance, "intelligibility_loss": intelligibility_loss}

 """
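
Note: the surrounding context pins the global RNG so that sampled batches stay comparable
across evaluations, and restores it so training randomness is unaffected. A minimal sketch
of that save/restore pattern (the seed value is hypothetical):

    import torch

    rng_state = torch.get_rng_state()   # snapshot global RNG before sampling
    torch.manual_seed(0)                # hypothetical fixed seed for a reproducible eval
    # ... diffusion sampling and metric computation ...
    torch.set_rng_state(rng_state)      # hand the original RNG state back to training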
@@ -250,11 +257,12 @@ if __name__ == '__main__':
 if __name__ == '__main__':
     from utils.util import load_model_from_config

-    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_vocoder_clips_from_dvae_archived_r3_b256_conditioning\\config.yml', 'generator',
-                                       also_load_savepoint=False, load_path='X:\\dlas\\experiments\\train_diffusion_vocoder_clips_from_dvae_archived_r3_b256_conditioning\\models\\80800_generator_ema.pth').cuda()
+    diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts9.yml', 'generator',
+                                       also_load_savepoint=False,
+                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9\\models\\7500_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
                 'conditioning_free': False, 'conditioning_free_k': 1,
-                'diffusion_schedule': 'linear', 'diffusion_type': 'original_vocoder'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 4, 'device': 'cuda', 'opt': {}}
+                'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 555, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
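
Note: with diffusion_type set to 'tts9_mel', the driver exercises the new path end to end;
perform_eval returns both metrics keyed as in the hunk above:

    metrics = AudioDiffusionFid(diffusion, opt_eval, env).perform_eval()
    print(metrics['frechet_distance'], metrics['intelligibility_loss'])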