added ability to export weights mid-training to avoid CBT to yank the weights while the training script is running

This commit is contained in:
mrq 2023-08-20 13:39:58 -05:00
parent fc576010ce
commit b105f6211e
3 changed files with 39 additions and 15 deletions

View File

@ -200,6 +200,9 @@ class Engines(dict[str, Engine]):
self._global_step = 0
self._micro_step = 0
for name, engine in self.items():
engine.name = name
@property
def global_step(self):
return self._global_step
@ -218,6 +221,18 @@ class Engines(dict[str, Engine]):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def export(self, userdata={}):
for name, engine in self.items():
outpath = cfg.ckpt_dir / name / "fp32.pth"
state_dict = {
"global_step": engine.global_step,
"micro_step": engine.micro_step,
'module': engine.module.state_dict(),
}
state_dict.update(userdata)
torch.save(state_dict, outpath)
print(f"Exported {name} to {outpath}")
def save_checkpoint(self, tag=None):
if not tag:
tag = cfg.trainer.save_tag
@ -246,7 +261,7 @@ class Engines(dict[str, Engine]):
p.unlink()
d.rmdir()
def load_checkpoint(self, tag=None):
def load_checkpoint(self, tag=None, module_only=False):
if not tag:
tag = cfg.trainer.load_tag
@ -256,8 +271,9 @@ class Engines(dict[str, Engine]):
tag=tag,
load_dir=load_dir,
load_module_strict=cfg.trainer.strict_loading,
load_optimizer_states=cfg.trainer.load_states,
load_lr_scheduler_states=cfg.trainer.load_states,
load_optimizer_states=False if module_only else cfg.trainer.load_states,
load_lr_scheduler_states=False if module_only else cfg.trainer.load_states,
load_module_only=module_only,
)
if cfg.trainer.restart_step_count:
engine.global_steps = 0

View File

@ -12,16 +12,7 @@ def main():
args = parser.parse_args()
engines = load_engines()
for name, engine in engines.items():
outpath = cfg.ckpt_dir / name / "fp32.pth"
torch.save({
"global_step": engine.global_step,
"micro_step": engine.micro_step,
'module': engine.module.to('cpu', dtype=torch.float32).state_dict(),
#'optimizer': engine.optimizer.state_dict(),
'symmap': get_phone_symmap(),
}, outpath)
print(f"Exported {name} to {outpath}")
engines.export(userdata={"symmap": get_phone_symmap()})
if __name__ == "__main__":
main()

View File

@ -33,6 +33,7 @@ from ..models import get_models
from .utils import to_device, do_gc
from ..utils import wrapper as ml
from ..data import get_phone_symmap # should decouple from this trainer script
_logger = logging.getLogger(__name__)
_engines: Engines
@ -69,6 +70,9 @@ def load_engines():
optimizer = None
lr_scheduler = None
# yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
# and then have ds_cfg pass in the config flag to use pytorch adamw
# I genuinely cannot validate if this ever actually gets used in DeepSpeed
if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
optimizer = ml.AdamW(
model.parameters(),
@ -86,6 +90,9 @@ def load_engines():
if "module" in state:
state = state["module"]
# should decouple the following from this trainer script
# probably with passing a fun that defaults to a lambda x: x deal
# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
@ -301,8 +308,18 @@ def train(
if "lr" in command:
rate = float(command.split(" ")[-1])
try:
engines.set_lr(rate)
print("Updating LR to:", rate)
except Exception as e:
print("Failed to set LR rate to:", rate, str(e))
if "export" in command:
engines.save_checkpoint()
last_save_step = engines.global_step
if is_global_leader():
engines.export(userdata={"symmap": get_phone_symmap()})
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency