ops
This commit is contained in:
parent
b105f6211e
commit
736c077282
|
@ -379,6 +379,7 @@ class Trainer:
|
|||
load_state_dict: bool = False
|
||||
load_states: bool = True
|
||||
strict_loading: bool = True
|
||||
load_module_only: bool = False
|
||||
restart_step_count: bool = False
|
||||
|
||||
aggressive_optimizations: bool = False
|
||||
|
|
|
@ -261,7 +261,7 @@ class Engines(dict[str, Engine]):
|
|||
p.unlink()
|
||||
d.rmdir()
|
||||
|
||||
def load_checkpoint(self, tag=None, module_only=False):
|
||||
def load_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = cfg.trainer.load_tag
|
||||
|
||||
|
@ -271,9 +271,9 @@ class Engines(dict[str, Engine]):
|
|||
tag=tag,
|
||||
load_dir=load_dir,
|
||||
load_module_strict=cfg.trainer.strict_loading,
|
||||
load_optimizer_states=False if module_only else cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=False if module_only else cfg.trainer.load_states,
|
||||
load_module_only=module_only,
|
||||
load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
|
||||
load_module_only=cfg.trainer.load_module_only,
|
||||
)
|
||||
if cfg.trainer.restart_step_count:
|
||||
engine.global_steps = 0
|
||||
|
|
|
@ -8,9 +8,12 @@ from .config import cfg
|
|||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
#parser.add_argument("--yaml", type=Path, default=None)
|
||||
parser.add_argument("--module-only", action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.module_only:
|
||||
cfg.trainer.load_module_only = True
|
||||
|
||||
engines = load_engines()
|
||||
engines.export(userdata={"symmap": get_phone_symmap()})
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user