diff --git a/vall_e/train.py b/vall_e/train.py index 70506ba..a93a6aa 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -17,10 +17,12 @@ import traceback from collections import defaultdict from tqdm import tqdm +import argparse -mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu") _logger = logging.getLogger(__name__) +mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu") + def train_feeder(engine, batch): with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): engine( @@ -153,6 +155,10 @@ def run_eval(engines, eval_name, dl): def train(): + parser = argparse.ArgumentParser("VALL-E TTS") + parser.add_argument("--eval", action="store_true") + args = parser.parse_args() + setup_logging(cfg.log_dir) train_dl, subtrain_dl, val_dl = create_train_val_dataloader() @@ -170,6 +176,9 @@ def train(): qnt.unload_model() + if args.eval: + return eval_fn(engines=trainer.load_engines()) + """ if cfg.trainer.load_webui: from .webui import start