updates to audio_diffusion_fid
This commit is contained in:
parent
9c6f776980
commit
1e3a8554a1
|
@ -5,16 +5,28 @@ import torch
|
||||||
|
|
||||||
from data.audio.unsupervised_audio_dataset import load_audio
|
from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from data.util import find_files_of_type, is_audio_file
|
from data.util import find_files_of_type, is_audio_file
|
||||||
|
from models.audio.vocoders.univnet.generator import UnivNetGenerator
|
||||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||||
from models.diffusion.respace import SpacedDiffusion, space_timesteps
|
from models.diffusion.respace import SpacedDiffusion, space_timesteps
|
||||||
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
|
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector, MelSpectrogramInjector
|
||||||
from utils.audio import plot_spectrogram
|
from utils.audio import plot_spectrogram
|
||||||
from utils.util import load_model_from_config
|
from utils.util import load_model_from_config
|
||||||
|
|
||||||
|
|
||||||
def load_speech_dvae():
    """
    Loads the "dvae" model through load_model_from_config using the diffusion vocoder experiment config,
    moves it to the CPU, puts it in eval mode and returns it.
    """
    config_path = "../experiments/train_diffusion_vocoder_22k_level.yml"
    speech_dvae = load_model_from_config(config_path, "dvae")
    speech_dvae = speech_dvae.cpu()
    # Callers use this model for inference only, so eval mode is set once here.
    speech_dvae.eval()
    return speech_dvae
|
||||||
|
|
||||||
|
|
||||||
|
def load_univnet_vocoder(checkpoint_path='univnet_c32_pretrained_libri.pt'):
    """
    Instantiates a UnivNetGenerator, loads a saved state dict into it and returns it on the CPU in eval mode.

    :param checkpoint_path: Path of the state dict file to load. Defaults to the checkpoint name that was
                            previously hard-coded, which is resolved relative to the current working directory.
    :return: The UnivNetGenerator with the weights loaded, on the CPU, in eval mode.
    """
    model = UnivNetGenerator()
    # map_location='cpu' allows a checkpoint that was saved from a GPU to be loaded on a machine that does not
    # have that device. The model is placed on the CPU below regardless, so nothing is lost by doing this.
    sd = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(sd)
    model = model.cpu()
    model.eval()
    return model
|
||||||
|
|
||||||
|
|
||||||
def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
|
def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
|
||||||
|
@ -24,6 +36,14 @@ def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
|
||||||
return TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': mel_norms_file},{})({'wav': wav})['mel']
|
return TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': mel_norms_file},{})({'wav': wav})['mel']
|
||||||
|
|
||||||
|
|
||||||
|
def wav_to_univnet_mel(wav):
    """
    Produces the MEL representation of an audio clip in the format consumed by the univnet vocoder.

    The clip is run through a MelSpectrogramInjector configured with a 24000 sampling rate, 100 MEL channels and
    a mel_fmax of 12000; the injector's 'mel' output is returned.
    """
    injector_opt = {'in': 'wav', 'out': 'mel', 'sampling_rate': 24000,
                    'n_mel_channels': 100, 'mel_fmax': 12000}
    mel_injector = MelSpectrogramInjector(injector_opt, {})
    injected = mel_injector({'wav': wav})
    return injected['mel']
|
||||||
|
|
||||||
|
|
||||||
def convert_mel_to_codes(dvae_model, mel):
|
def convert_mel_to_codes(dvae_model, mel):
|
||||||
"""
|
"""
|
||||||
Converts an audio clip into discrete codes.
|
Converts an audio clip into discrete codes.
|
||||||
|
|
|
@ -15,7 +15,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from models.clip.mel_text_clip import MelTextCLIP
|
from models.clip.mel_text_clip import MelTextCLIP
|
||||||
from models.audio.tts.tacotron2 import text_to_sequence
|
from models.audio.tts.tacotron2 import text_to_sequence
|
||||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
||||||
convert_mel_to_codes
|
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
|
||||||
from utils.util import ceil_multiple, opt_get
|
from utils.util import ceil_multiple, opt_get
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,16 +43,19 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
conditioning_free_k=conditioning_free_k)
|
conditioning_free_k=conditioning_free_k)
|
||||||
self.dev = self.env['device']
|
self.dev = self.env['device']
|
||||||
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
|
||||||
|
self.local_modules = {}
|
||||||
if mode == 'tts':
|
if mode == 'tts':
|
||||||
self.diffusion_fn = self.perform_diffusion_tts
|
self.diffusion_fn = self.perform_diffusion_tts
|
||||||
elif mode == 'original_vocoder':
|
elif mode == 'original_vocoder':
|
||||||
self.dvae = load_speech_dvae().to(self.env['device'])
|
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||||
self.dvae.eval()
|
|
||||||
self.diffusion_fn = self.perform_original_diffusion_vocoder
|
self.diffusion_fn = self.perform_original_diffusion_vocoder
|
||||||
elif mode == 'vocoder':
|
elif mode == 'vocoder':
|
||||||
self.dvae = load_speech_dvae().to(self.env['device'])
|
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||||
self.dvae.eval()
|
|
||||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||||
|
elif mode == 'tts9_mel':
|
||||||
|
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||||
|
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
|
||||||
|
self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
|
||||||
|
|
||||||
def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
|
def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
|
||||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||||
|
@ -71,8 +74,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
def perform_original_diffusion_vocoder(self, audio, codes, text, sample_rate=11025):
|
def perform_original_diffusion_vocoder(self, audio, codes, text, sample_rate=11025):
|
||||||
mel = wav_to_mel(audio)
|
mel = wav_to_mel(audio)
|
||||||
mel_codes = convert_mel_to_codes(self.dvae, mel)
|
mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
|
||||||
back_to_mel = self.dvae.decode(mel_codes)[0]
|
back_to_mel = self.local_modules['dvae'].decode(mel_codes)[0]
|
||||||
orig_audio = audio
|
orig_audio = audio
|
||||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||||
|
|
||||||
|
@ -96,7 +99,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
def perform_diffusion_vocoder(self, audio, codes, text, sample_rate=5500):
|
def perform_diffusion_vocoder(self, audio, codes, text, sample_rate=5500):
|
||||||
mel = wav_to_mel(audio)
|
mel = wav_to_mel(audio)
|
||||||
mel_codes = convert_mel_to_codes(self.dvae, mel)
|
mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
|
||||||
text_codes = text_to_sequence(text)
|
text_codes = text_to_sequence(text)
|
||||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||||
|
|
||||||
|
@ -115,11 +118,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
return gen, real_resampled, sample_rate
|
return gen, real_resampled, sample_rate
|
||||||
|
|
||||||
|
|
||||||
def perform_diffusion_tts9_from_codes(self, audio, codes, text, sample_rate=5500):
|
def perform_diffusion_tts9_mel_from_codes(self, audio, codes, text):
|
||||||
|
SAMPLE_RATE = 24000
|
||||||
mel = wav_to_mel(audio)
|
mel = wav_to_mel(audio)
|
||||||
mel_codes = convert_mel_to_codes(self.dvae, mel)
|
mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
|
||||||
text_codes = text_to_sequence(text)
|
real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
|
||||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
univnet_mel = wav_to_univnet_mel(audio) # to be used for a conditioning input
|
||||||
|
|
||||||
output_size = real_resampled.shape[-1]
|
output_size = real_resampled.shape[-1]
|
||||||
aligned_codes_compression_factor = output_size // mel_codes.shape[-1]
|
aligned_codes_compression_factor = output_size // mel_codes.shape[-1]
|
||||||
|
@ -129,11 +133,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
if padding_needed_for_codes > 0:
|
if padding_needed_for_codes > 0:
|
||||||
mel_codes = F.pad(mel_codes, (0, padding_needed_for_codes))
|
mel_codes = F.pad(mel_codes, (0, padding_needed_for_codes))
|
||||||
output_shape = (1, 1, padded_size)
|
output_shape = (1, 1, padded_size)
|
||||||
gen = self.diffuser.p_sample_loop(self.model, output_shape,
|
gen_mel = self.diffuser.p_sample_loop(self.model, output_shape,
|
||||||
model_kwargs={'tokens': mel_codes,
|
model_kwargs={'aligned_conditioning': mel_codes,
|
||||||
'conditioning_input': audio.unsqueeze(0),
|
'conditioning_input': univnet_mel})
|
||||||
'unaligned_input': torch.tensor(text_codes, device=audio.device).unsqueeze(0)})
|
|
||||||
return gen, real_resampled, sample_rate
|
gen_wav = self.local_modules['vocoder'](gen_mel)
|
||||||
|
return gen_wav, real_resampled, SAMPLE_RATE
|
||||||
|
|
||||||
def load_projector(self):
|
def load_projector(self):
|
||||||
"""
|
"""
|
||||||
|
@ -187,11 +192,11 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
projector = self.load_projector().to(self.env['device'])
|
projector = self.load_projector().to(self.env['device'])
|
||||||
projector.eval()
|
projector.eval()
|
||||||
if hasattr(self, 'dvae'):
|
|
||||||
self.dvae = self.dvae.to(self.env['device'])
|
|
||||||
|
|
||||||
w2v = self.load_w2v().to(self.env['device'])
|
w2v = self.load_w2v().to(self.env['device'])
|
||||||
w2v.eval()
|
w2v.eval()
|
||||||
|
for k, mod in self.local_modules.items():
|
||||||
|
self.local_modules[k] = mod.to(self.env['device'])
|
||||||
|
|
||||||
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
||||||
rng_state = torch.get_rng_state()
|
rng_state = torch.get_rng_state()
|
||||||
|
@ -226,10 +231,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
intelligibility_loss = intelligibility_loss / distributed.get_world_size()
|
intelligibility_loss = intelligibility_loss / distributed.get_world_size()
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
if hasattr(self, 'dvae'):
|
|
||||||
self.dvae = self.dvae.to('cpu')
|
|
||||||
torch.set_rng_state(rng_state)
|
torch.set_rng_state(rng_state)
|
||||||
|
|
||||||
|
# Put modules used for evaluation back into CPU memory.
|
||||||
|
for k, mod in self.local_modules.items():
|
||||||
|
self.local_modules[k] = mod.cpu()
|
||||||
|
|
||||||
return {"frechet_distance": frechet_distance, "intelligibility_loss": intelligibility_loss}
|
return {"frechet_distance": frechet_distance, "intelligibility_loss": intelligibility_loss}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -250,11 +257,12 @@ if __name__ == '__main__':
|
||||||
if __name__ == '__main__':
    # Manual test run: loads a generator checkpoint and performs a single evaluation pass using the 'tts9_mel'
    # diffusion type. Every path below is specific to one machine and must be edited before running elsewhere.
    from utils.util import load_model_from_config

    diffusion = load_model_from_config(
        'X:\\dlas\\experiments\\train_diffusion_tts9.yml', 'generator',
        also_load_savepoint=False,
        load_path='X:\\dlas\\experiments\\train_diffusion_tts9\\models\\7500_generator_ema.pth',
    ).cuda()
    opt_eval = {
        'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv',
        'diffusion_steps': 100,
        'conditioning_free': False,
        'conditioning_free_k': 1,
        'diffusion_schedule': 'linear',
        'diffusion_type': 'tts9_mel',
    }
    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 555, 'device': 'cuda', 'opt': {}}
    # Bound to fid_evaluator rather than `eval` so that the builtin of that name is not shadowed.
    fid_evaluator = AudioDiffusionFid(diffusion, opt_eval, env)
    print(fid_evaluator.perform_eval())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user