added kludgy eval only so I don't have to start training, type eval, stop training, then delete the logs for that session

This commit is contained in:
mrq 2024-05-25 17:39:51 -05:00
parent ddbacde0d1
commit d760924719

View File

@ -17,10 +17,12 @@ import traceback
from collections import defaultdict from collections import defaultdict
from tqdm import tqdm from tqdm import tqdm
import argparse
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch): def train_feeder(engine, batch):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
engine( engine(
@ -153,6 +155,10 @@ def run_eval(engines, eval_name, dl):
def train(): def train():
parser = argparse.ArgumentParser("VALL-E TTS")
parser.add_argument("--eval", action="store_true")
args = parser.parse_args()
setup_logging(cfg.log_dir) setup_logging(cfg.log_dir)
train_dl, subtrain_dl, val_dl = create_train_val_dataloader() train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@ -170,6 +176,9 @@ def train():
qnt.unload_model() qnt.unload_model()
if args.eval:
return eval_fn(engines=trainer.load_engines())
""" """
if cfg.trainer.load_webui: if cfg.trainer.load_webui:
from .webui import start from .webui import start