forked from mrq/DL-Art-School
fix audio_diffusion_fid for autoregressive latent inputs
This commit is contained in:
parent
8ea5c307fb
commit
f2c172291f
|
@ -327,7 +327,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clvp.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_mel_flat_autoregressive_inputs.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
|
@ -12,6 +12,7 @@ import numpy as np
|
|||
import trainer.eval.evaluator as evaluator
|
||||
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
||||
from models.clip.mel_text_clip import MelTextCLIP
|
||||
from models.audio.tts.tacotron2 import text_to_sequence
|
||||
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
|
||||
|
@ -23,6 +24,9 @@ from utils.util import ceil_multiple, opt_get
|
|||
class AudioDiffusionFid(evaluator.Evaluator):
|
||||
"""
|
||||
Evaluator produces generate from a diffusion model, then uses a CLIP model to judge the similarity between text & speech.
|
||||
|
||||
This evaluator is kind of a mess. It has been repeatedly modified to work with several different model types, which
|
||||
means it is bloated beyond belief. I would not recommend attempting to understand what is going on here.
|
||||
"""
|
||||
def __init__(self, model, opt_eval, env):
|
||||
super().__init__(model, opt_eval, env, uses_all_ddp=True)
|
||||
|
@ -53,11 +57,22 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
elif mode == 'vocoder':
|
||||
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||
elif mode == 'tts9_mel':
|
||||
elif 'tts9_mel' in mode:
|
||||
mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
|
||||
self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
|
||||
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
|
||||
self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
|
||||
if mode == 'tts9_mel_autoin':
|
||||
self.local_modules['autoregressive'] = load_model_from_config("../experiments/train_gpt_tts_unified.yml",
|
||||
model_name='gpt',
|
||||
also_load_savepoint=False,
|
||||
load_path='../experiments/unified_large_diverse_basis.pth',
|
||||
device=torch.device('cpu')).cuda().eval()
|
||||
self.tts9_codegen = self.tts9_get_autoregressive_codes
|
||||
else:
|
||||
self.tts9_codegen = self.tts9_get_dvae_codes
|
||||
|
||||
|
||||
def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
|
||||
real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)
|
||||
|
@ -119,25 +134,28 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
'unaligned_input': torch.tensor(text_codes, device=audio.device).unsqueeze(0)})
|
||||
return gen, real_resampled, sample_rate
|
||||
|
||||
def tts9_get_autoregressive_codes(self, mel, text):
|
||||
mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
|
||||
text_codes = torch.LongTensor(self.bpe_tokenizer.encode(text)).unsqueeze(0).to(mel.device)
|
||||
cond_inputs = mel.unsqueeze(1)
|
||||
auto_latents = self.local_modules['autoregressive'].forward(cond_inputs, text_codes,
|
||||
torch.tensor([text_codes.shape[-1]], device=mel.device),
|
||||
mel_codes,
|
||||
torch.tensor([mel_codes.shape[-1]], device=mel.device),
|
||||
text_first=True, raw_mels=None, return_latent=True,
|
||||
clip_inputs=False)
|
||||
return auto_latents
|
||||
|
||||
def tts9_get_dvae_codes(self, mel, text):
|
||||
return convert_mel_to_codes(self.local_modules['dvae'], mel)
|
||||
|
||||
def perform_diffusion_tts9_mel_from_codes(self, audio, codes, text):
|
||||
SAMPLE_RATE = 24000
|
||||
mel = wav_to_mel(audio)
|
||||
mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
|
||||
mel_codes = self.tts9_codegen(mel, text)
|
||||
real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
|
||||
univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=False) # to be used for a conditioning input, but also guides output shape.
|
||||
|
||||
output_size = univnet_mel.shape[-1]
|
||||
aligned_codes_compression_factor = output_size // mel_codes.shape[-1]
|
||||
if hasattr(self.model, 'alignment_size'):
|
||||
padded_size = ceil_multiple(output_size, self.model.alignment_size)
|
||||
padding_added = padded_size - output_size
|
||||
padding_needed_for_codes = padding_added // aligned_codes_compression_factor
|
||||
if padding_needed_for_codes > 0:
|
||||
mel_codes = F.pad(mel_codes, (0, padding_needed_for_codes))
|
||||
output_shape = (1, 100, padded_size)
|
||||
else:
|
||||
output_shape = univnet_mel.shape
|
||||
output_shape = univnet_mel.shape
|
||||
gen_mel = self.diffuser.p_sample_loop(self.model, output_shape,
|
||||
model_kwargs={'aligned_conditioning': mel_codes,
|
||||
'conditioning_input': univnet_mel})
|
||||
|
@ -265,12 +283,12 @@ if __name__ == '__main__':
|
|||
from utils.util import load_model_from_config
|
||||
# 34k; no conditioning_free: {'frechet_distance': tensor(1.4559, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(151.9112, device='cuda:0')}
|
||||
# 34k; conditioning_free: {'frechet_distance': tensor(1.4059, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(118.3377, device='cuda:0')}
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat.yml', 'generator',
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat_autoregressive_inputs.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
load_path='X:\\dlas\\experiments\\train_diffusion_tts_mel_flat0\\models\\34000_generator_ema.pth').cuda()
|
||||
load_path='X:\\dlas\\experiments\\tts_flat_autoregressive_inputs_r2_initial\\models\\500_generator.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
|
||||
'conditioning_free': True, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 560, 'device': 'cuda', 'opt': {}}
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel_autoin'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 561, 'device': 'cuda', 'opt': {}}
|
||||
eval = AudioDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
Loading…
Reference in New Issue
Block a user