added kludgy eval only so I don't have to start training, type eval, stop training, then delete the logs for that session
This commit is contained in:
parent
ddbacde0d1
commit
d760924719
|
@ -17,10 +17,12 @@ import traceback
|
|||
from collections import defaultdict
|
||||
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
|
||||
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
|
||||
|
||||
def train_feeder(engine, batch):
|
||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||
engine(
|
||||
|
@ -153,6 +155,10 @@ def run_eval(engines, eval_name, dl):
|
|||
|
||||
|
||||
def train():
|
||||
parser = argparse.ArgumentParser("VALL-E TTS")
|
||||
parser.add_argument("--eval", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
setup_logging(cfg.log_dir)
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
@ -170,6 +176,9 @@ def train():
|
|||
|
||||
qnt.unload_model()
|
||||
|
||||
if args.eval:
|
||||
return eval_fn(engines=trainer.load_engines())
|
||||
|
||||
"""
|
||||
if cfg.trainer.load_webui:
|
||||
from .webui import start
|
||||
|
|
Loading…
Reference in New Issue
Block a user