This commit is contained in:
mrq 2023-08-20 13:42:18 -05:00
parent b105f6211e
commit 736c077282
3 changed files with 9 additions and 5 deletions

View File

@ -379,6 +379,7 @@ class Trainer:
load_state_dict: bool = False
load_states: bool = True
strict_loading: bool = True
load_module_only: bool = False
restart_step_count: bool = False
aggressive_optimizations: bool = False

View File

@ -261,7 +261,7 @@ class Engines(dict[str, Engine]):
p.unlink()
d.rmdir()
def load_checkpoint(self, tag=None, module_only=False):
def load_checkpoint(self, tag=None):
if not tag:
tag = cfg.trainer.load_tag
@ -271,9 +271,9 @@ class Engines(dict[str, Engine]):
tag=tag,
load_dir=load_dir,
load_module_strict=cfg.trainer.strict_loading,
load_optimizer_states=False if module_only else cfg.trainer.load_states,
load_lr_scheduler_states=False if module_only else cfg.trainer.load_states,
load_module_only=module_only,
load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
load_module_only=cfg.trainer.load_module_only,
)
if cfg.trainer.restart_step_count:
engine.global_steps = 0

View File

@ -8,9 +8,12 @@ from .config import cfg
def main():
parser = argparse.ArgumentParser("Save trained model to path.")
#parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--module-only", action='store_true')
args = parser.parse_args()
if args.module_only:
cfg.trainer.load_module_only = True
engines = load_engines()
engines.export(userdata={"symmap": get_phone_symmap()})