fix mel normalization
parent 8437bb0c53
commit 54202aa099
@@ -22,10 +22,10 @@ def load_speech_dvae():
 def load_univnet_vocoder():
     model = UnivNetGenerator()
-    sd = torch.load('univnet_c32_pretrained_libri.pt')
-    model.load_state_dict(sd)
+    sd = torch.load('../experiments/univnet_c32_pretrained_libri.pt')
+    model.load_state_dict(sd['model_g'])
     model = model.cpu()
-    model.eval()
+    model.eval(inference=True)
     return model

@@ -36,12 +36,13 @@ def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
     return TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': mel_norms_file},{})({'wav': wav})['mel']


-def wav_to_univnet_mel(wav):
+def wav_to_univnet_mel(wav, do_normalization=False):
     """
     Converts an audio clip into a MEL tensor that the univnet vocoder knows how to decode.
     """
     return MelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'sampling_rate': 24000,
-                                   'n_mel_channels': 100, 'mel_fmax': 12000},{})({'wav': wav})['mel']
+                                   'n_mel_channels': 100, 'mel_fmax': 12000,
+                                   'do_normalization': do_normalization},{})({'wav': wav})['mel']


 def convert_mel_to_codes(dvae_model, mel):
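For orientation (not part of the commit), a minimal sketch of how the two patched helpers above might be used together; the clip path is hypothetical, and it assumes the relative checkpoint path inside load_univnet_vocoder resolves from the working directory:

import torch
import torchaudio
from scripts.audio.gen.speech_synthesis_utils import load_univnet_vocoder, wav_to_univnet_mel

wav, sr = torchaudio.load('example_clip.wav')          # hypothetical input clip
wav = torchaudio.functional.resample(wav, sr, 24000)   # the injector above is configured for 24 kHz

vocoder = load_univnet_vocoder()                       # loads the univnet_c32 weights onto the CPU
mel = wav_to_univnet_mel(wav, do_normalization=False)  # raw (unnormalized) mel, directly decodable
with torch.no_grad():
    reconstruction = vocoder.inference(mel)            # vocode the mel back to a waveform
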
@@ -318,7 +318,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_noisy_audio_clips_classifier.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts9_mel.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

@@ -16,6 +16,7 @@ from models.clip.mel_text_clip import MelTextCLIP
 from models.audio.tts.tacotron2 import text_to_sequence
 from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
     convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
+from trainer.injectors.audio_injectors import denormalize_tacotron_mel
 from utils.util import ceil_multiple, opt_get

@@ -124,7 +125,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
         mel = wav_to_mel(audio)
         mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
         real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
-        univnet_mel = wav_to_univnet_mel(audio)  # to be used for a conditioning input
+        univnet_mel = wav_to_univnet_mel(audio, do_normalization=False)  # to be used for a conditioning input

         output_size = univnet_mel.shape[-1]
         aligned_codes_compression_factor = output_size // mel_codes.shape[-1]

@@ -138,7 +139,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
                           model_kwargs={'aligned_conditioning': mel_codes,
                                         'conditioning_input': univnet_mel})
         # denormalize mel
-        gen_mel = ((gen_mel+1)/2)*(self.mel_max-self.mel_min)+self.mel_min
+        gen_mel = denormalize_tacotron_mel(gen_mel)

         gen_wav = self.local_modules['vocoder'].inference(gen_mel)
         real_dec = self.local_modules['vocoder'].inference(univnet_mel)
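The removed inline expression and the new helper compute the same denormalization provided self.mel_max / self.mel_min held the Tacotron constants added further down in this commit; a small check of that assumption (not part of the commit):

import torch
from trainer.injectors.audio_injectors import (TACOTRON_MEL_MAX, TACOTRON_MEL_MIN,
                                               denormalize_tacotron_mel)

# Dummy normalized mel in [-1, 1], standing in for the diffusion output gen_mel.
gen_mel = torch.rand(1, 100, 400) * 2 - 1
mel_max, mel_min = TACOTRON_MEL_MAX, TACOTRON_MEL_MIN  # assumed values of self.mel_max / self.mel_min
old_expr = ((gen_mel + 1) / 2) * (mel_max - mel_min) + mel_min
assert torch.allclose(old_expr, denormalize_tacotron_mel(gen_mel))
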
@@ -263,7 +264,7 @@ if __name__ == '__main__':

     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts9_mel.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\10000_generator_ema.pth').cuda()
+                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\5000_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
                 'conditioning_free': False, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}

@@ -7,6 +7,14 @@ import torchaudio
 from trainer.inject import Injector
 from utils.util import opt_get, load_model_from_config

+TACOTRON_MEL_MAX = 2.3143386840820312
+TACOTRON_MEL_MIN = -11.512925148010254
+
+def normalize_tacotron_mel(mel):
+    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
+
+def denormalize_tacotron_mel(norm_mel):
+    return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN

 class MelSpectrogramInjector(Injector):
     def __init__(self, opt, env):
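A quick sanity check (not part of the commit) that the two new helpers are exact inverses over the Tacotron log-mel range:

import torch
from trainer.injectors.audio_injectors import (TACOTRON_MEL_MAX, TACOTRON_MEL_MIN,
                                               normalize_tacotron_mel, denormalize_tacotron_mel)

# Values drawn from the Tacotron log-mel range; the round trip should be lossless
# up to floating point error.
mel = torch.rand(2, 100, 300) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
assert torch.allclose(denormalize_tacotron_mel(normalize_tacotron_mel(mel)), mel, atol=1e-5)
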
@@ -21,14 +29,7 @@ class MelSpectrogramInjector(Injector):
         mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
         sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
         self.stft = TacotronSTFT(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
-        self.mel_norm_file = opt_get(opt, ['mel_norm_file'], None)
-        if self.mel_norm_file is not None:
-            # Note that this format is different from TorchMelSpectrogramInjector.
-            mel_means, mel_max, mel_min, mel_stds, mel_vars = torch.load(self.mel_norm_file)
-            self.mel_max = mel_max
-            self.mel_min = mel_min
-        else:
-            self.mel_max = None
+        self.do_normalization = opt_get(opt, ['do_normalization'], None)  # This is different from the TorchMelSpectrogramInjector. This just normalizes to the range [-1,1]

     def forward(self, state):
         inp = state[self.input]
@@ -37,8 +38,8 @@ class MelSpectrogramInjector(Injector):
         assert len(inp.shape) == 2
         self.stft = self.stft.to(inp.device)
         mel = self.stft.mel_spectrogram(inp)
-        if self.mel_max is not None:
-            mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1
+        if self.do_normalization:
+            mel = normalize_tacotron_mel(mel)
         return {self.output: mel}

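A minimal sketch (not part of the commit) of driving the patched injector directly, mirroring the option dict used by wav_to_univnet_mel above; the dummy waveform stands in for a real 24 kHz clip:

import torch
from trainer.injectors.audio_injectors import MelSpectrogramInjector, denormalize_tacotron_mel

# One second of dummy 24 kHz audio in place of a real clip.
wav = torch.rand(1, 24000) * 2 - 1

inj = MelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'sampling_rate': 24000,
                              'n_mel_channels': 100, 'mel_fmax': 12000,
                              'do_normalization': True}, {})
norm_mel = inj({'wav': wav})['mel']       # scaled with normalize_tacotron_mel
mel = denormalize_tacotron_mel(norm_mel)  # back to the raw log-mel range the vocoder expects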