This commit is contained in:
James Betker 2022-02-27 14:49:11 -07:00
parent ba155e4e2f
commit ac920798bb
3 changed files with 6 additions and 2 deletions

View File

@ -62,6 +62,10 @@ def tacotron_symbols():
return list(_symbol_to_id.keys()) return list(_symbol_to_id.keys())
def tacotron_symbol_mapping():
return _symbol_to_id.copy()
def _clean_text(text, cleaner_names): def _clean_text(text, cleaner_names):
for name in cleaner_names: for name in cleaner_names:
cleaner = getattr(cleaners, name) cleaner = getattr(cleaners, name)

View File

@ -317,7 +317,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_wav2vec_mass_large_/train_wav2vec_mass.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/debug_diffusion_tts7.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -131,7 +131,7 @@ if __name__ == '__main__':
#'kenlm_path': 'Y:\\bookscorpus-5gram\\5gram.bin', #'kenlm_path': 'Y:\\bookscorpus-5gram\\5gram.bin',
} }
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False) model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
weights = torch.load('D:\\dlas\\experiments\\train_wav2vec_mass_large2\\models\\17000_wav2vec.pth') weights = torch.load('D:\\dlas\\experiments\\train_wav2vec_mass_large2\\models\\22500_wav2vec.pth')
model.load_state_dict(weights) model.load_state_dict(weights)
model = model.cuda() model = model.cuda()
eval = WerEvaluator(model, opt_eval, env) eval = WerEvaluator(model, opt_eval, env)