From 736c07728262dbec7d7ea49a6579364b034b7bc9 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 20 Aug 2023 13:42:18 -0500 Subject: [PATCH] ops --- vall_e/config.py | 1 + vall_e/engines/base.py | 8 ++++---- vall_e/export.py | 5 ++++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index f2cf356..9205814 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -379,6 +379,7 @@ class Trainer: load_state_dict: bool = False load_states: bool = True strict_loading: bool = True + load_module_only: bool = False restart_step_count: bool = False aggressive_optimizations: bool = False diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index f820657..41b2bc6 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -261,7 +261,7 @@ class Engines(dict[str, Engine]): p.unlink() d.rmdir() - def load_checkpoint(self, tag=None, module_only=False): + def load_checkpoint(self, tag=None): if not tag: tag = cfg.trainer.load_tag @@ -271,9 +271,9 @@ class Engines(dict[str, Engine]): tag=tag, load_dir=load_dir, load_module_strict=cfg.trainer.strict_loading, - load_optimizer_states=False if module_only else cfg.trainer.load_states, - load_lr_scheduler_states=False if module_only else cfg.trainer.load_states, - load_module_only=module_only, + load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states, + load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states, + load_module_only=cfg.trainer.load_module_only, ) if cfg.trainer.restart_step_count: engine.global_steps = 0 diff --git a/vall_e/export.py b/vall_e/export.py index 464803b..f7778c6 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -8,9 +8,12 @@ from .config import cfg def main(): parser = argparse.ArgumentParser("Save trained model to path.") - #parser.add_argument("--yaml", type=Path, default=None) + parser.add_argument("--module-only", action='store_true') args = parser.parse_args() + if args.module_only: + cfg.trainer.load_module_only = True + engines = load_engines() engines.export(userdata={"symmap": get_phone_symmap()})