forked from mrq/DL-Art-School
misc
This commit is contained in:
parent
ba155e4e2f
commit
ac920798bb
|
@ -62,6 +62,10 @@ def tacotron_symbols():
|
||||||
return list(_symbol_to_id.keys())
|
return list(_symbol_to_id.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def tacotron_symbol_mapping():
|
||||||
|
return _symbol_to_id.copy()
|
||||||
|
|
||||||
|
|
||||||
def _clean_text(text, cleaner_names):
|
def _clean_text(text, cleaner_names):
|
||||||
for name in cleaner_names:
|
for name in cleaner_names:
|
||||||
cleaner = getattr(cleaners, name)
|
cleaner = getattr(cleaners, name)
|
||||||
|
|
|
@ -317,7 +317,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_wav2vec_mass_large_/train_wav2vec_mass.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/debug_diffusion_tts7.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -131,7 +131,7 @@ if __name__ == '__main__':
|
||||||
#'kenlm_path': 'Y:\\bookscorpus-5gram\\5gram.bin',
|
#'kenlm_path': 'Y:\\bookscorpus-5gram\\5gram.bin',
|
||||||
}
|
}
|
||||||
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
|
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
|
||||||
weights = torch.load('D:\\dlas\\experiments\\train_wav2vec_mass_large2\\models\\17000_wav2vec.pth')
|
weights = torch.load('D:\\dlas\\experiments\\train_wav2vec_mass_large2\\models\\22500_wav2vec.pth')
|
||||||
model.load_state_dict(weights)
|
model.load_state_dict(weights)
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
eval = WerEvaluator(model, opt_eval, env)
|
eval = WerEvaluator(model, opt_eval, env)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user