added kludgy eval only so I don't have to start training, type eval, stop training, then delete the logs for that session
This commit is contained in:
parent
ddbacde0d1
commit
d760924719
|
@ -17,10 +17,12 @@ import traceback
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import argparse
|
||||||
|
|
||||||
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
|
||||||
|
|
||||||
def train_feeder(engine, batch):
|
def train_feeder(engine, batch):
|
||||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||||
engine(
|
engine(
|
||||||
|
@ -153,6 +155,10 @@ def run_eval(engines, eval_name, dl):
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
parser = argparse.ArgumentParser("VALL-E TTS")
|
||||||
|
parser.add_argument("--eval", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
setup_logging(cfg.log_dir)
|
setup_logging(cfg.log_dir)
|
||||||
|
|
||||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||||
|
@ -170,6 +176,9 @@ def train():
|
||||||
|
|
||||||
qnt.unload_model()
|
qnt.unload_model()
|
||||||
|
|
||||||
|
if args.eval:
|
||||||
|
return eval_fn(engines=trainer.load_engines())
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if cfg.trainer.load_webui:
|
if cfg.trainer.load_webui:
|
||||||
from .webui import start
|
from .webui import start
|
||||||
|
|
Loading…
Reference in New Issue
Block a user