added ability to export weights mid-training to avoid CBT to yank the weights while the training script is running
This commit is contained in:
parent
fc576010ce
commit
b105f6211e
|
@ -200,6 +200,9 @@ class Engines(dict[str, Engine]):
|
|||
self._global_step = 0
|
||||
self._micro_step = 0
|
||||
|
||||
for name, engine in self.items():
|
||||
engine.name = name
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self._global_step
|
||||
|
@ -218,6 +221,18 @@ class Engines(dict[str, Engine]):
|
|||
for engine in self.values():
|
||||
engine.dispatch_attribute(*args, **kwargs)
|
||||
|
||||
def export(self, userdata={}):
|
||||
for name, engine in self.items():
|
||||
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state_dict = {
|
||||
"global_step": engine.global_step,
|
||||
"micro_step": engine.micro_step,
|
||||
'module': engine.module.state_dict(),
|
||||
}
|
||||
state_dict.update(userdata)
|
||||
torch.save(state_dict, outpath)
|
||||
print(f"Exported {name} to {outpath}")
|
||||
|
||||
def save_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = cfg.trainer.save_tag
|
||||
|
@ -246,7 +261,7 @@ class Engines(dict[str, Engine]):
|
|||
p.unlink()
|
||||
d.rmdir()
|
||||
|
||||
def load_checkpoint(self, tag=None):
|
||||
def load_checkpoint(self, tag=None, module_only=False):
|
||||
if not tag:
|
||||
tag = cfg.trainer.load_tag
|
||||
|
||||
|
@ -256,8 +271,9 @@ class Engines(dict[str, Engine]):
|
|||
tag=tag,
|
||||
load_dir=load_dir,
|
||||
load_module_strict=cfg.trainer.strict_loading,
|
||||
load_optimizer_states=cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=cfg.trainer.load_states,
|
||||
load_optimizer_states=False if module_only else cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=False if module_only else cfg.trainer.load_states,
|
||||
load_module_only=module_only,
|
||||
)
|
||||
if cfg.trainer.restart_step_count:
|
||||
engine.global_steps = 0
|
||||
|
|
|
@ -12,16 +12,7 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
engines = load_engines()
|
||||
for name, engine in engines.items():
|
||||
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
||||
torch.save({
|
||||
"global_step": engine.global_step,
|
||||
"micro_step": engine.micro_step,
|
||||
'module': engine.module.to('cpu', dtype=torch.float32).state_dict(),
|
||||
#'optimizer': engine.optimizer.state_dict(),
|
||||
'symmap': get_phone_symmap(),
|
||||
}, outpath)
|
||||
print(f"Exported {name} to {outpath}")
|
||||
engines.export(userdata={"symmap": get_phone_symmap()})
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -33,6 +33,7 @@ from ..models import get_models
|
|||
|
||||
from .utils import to_device, do_gc
|
||||
from ..utils import wrapper as ml
|
||||
from ..data import get_phone_symmap # should decouple from this trainer script
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_engines: Engines
|
||||
|
@ -69,6 +70,9 @@ def load_engines():
|
|||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
# yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
|
||||
# and then have ds_cfg pass in the config flag to use pytorch adamw
|
||||
# I genuinely cannot validate if this ever actually gets used in DeepSpeed
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
|
||||
optimizer = ml.AdamW(
|
||||
model.parameters(),
|
||||
|
@ -86,6 +90,9 @@ def load_engines():
|
|||
if "module" in state:
|
||||
state = state["module"]
|
||||
|
||||
# should decouple the following from this trainer script
|
||||
# probably with passing a fun that defaults to a lambda x: x deal
|
||||
|
||||
# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
|
||||
if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
|
||||
o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
|
||||
|
@ -301,8 +308,18 @@ def train(
|
|||
|
||||
if "lr" in command:
|
||||
rate = float(command.split(" ")[-1])
|
||||
try:
|
||||
engines.set_lr(rate)
|
||||
print("Updating LR to:", rate)
|
||||
except Exception as e:
|
||||
print("Failed to set LR rate to:", rate, str(e))
|
||||
|
||||
if "export" in command:
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
|
||||
if is_global_leader():
|
||||
engines.export(userdata={"symmap": get_phone_symmap()})
|
||||
|
||||
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user