added ability to export weights mid-training, to avoid the CBT of having to yank the weights while the training script is running
parent fc576010ce
commit b105f6211e
@@ -200,6 +200,9 @@ class Engines(dict[str, Engine]):
 		self._global_step = 0
 		self._micro_step = 0
 
+		for name, engine in self.items():
+			engine.name = name
+
 	@property
 	def global_step(self):
 		return self._global_step
@@ -218,6 +221,18 @@ class Engines(dict[str, Engine]):
 		for engine in self.values():
 			engine.dispatch_attribute(*args, **kwargs)
 
+	def export(self, userdata={}):
+		for name, engine in self.items():
+			outpath = cfg.ckpt_dir / name / "fp32.pth"
+			state_dict = {
+				"global_step": engine.global_step,
+				"micro_step": engine.micro_step,
+				'module': engine.module.state_dict(),
+			}
+			state_dict.update(userdata)
+			torch.save(state_dict, outpath)
+			print(f"Exported {name} to {outpath}")
+
 	def save_checkpoint(self, tag=None):
 		if not tag:
 			tag = cfg.trainer.save_tag
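The export pattern above boils down to writing each engine's weights, step counters, and any extra userdata into a single fp32.pth per model, without stopping the running training loop. A minimal standalone sketch of that pattern using plain torch; the module names, checkpoint directory, and userdata below are illustrative stand-ins, not the project's API:

# Sketch of the mid-training export pattern introduced by Engines.export():
# save each module's weights plus step counters and arbitrary userdata
# to <ckpt_dir>/<name>/fp32.pth.
from pathlib import Path

import torch
import torch.nn as nn

def export_weights(modules: dict[str, nn.Module], ckpt_dir: Path,
                   global_step: int = 0, micro_step: int = 0, userdata: dict | None = None):
    for name, module in modules.items():
        outpath = ckpt_dir / name / "fp32.pth"
        outpath.parent.mkdir(parents=True, exist_ok=True)
        state_dict = {
            "global_step": global_step,
            "micro_step": micro_step,
            "module": module.state_dict(),
        }
        state_dict.update(userdata or {})  # e.g. {"symmap": ...}
        torch.save(state_dict, outpath)
        print(f"Exported {name} to {outpath}")

if __name__ == "__main__":
    # toy model and path, purely for illustration
    export_weights({"example": nn.Linear(4, 4)}, Path("./ckpt"), userdata={"symmap": {"a": 1}})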
@@ -246,7 +261,7 @@ class Engines(dict[str, Engine]):
 					p.unlink()
 				d.rmdir()
 
-	def load_checkpoint(self, tag=None):
+	def load_checkpoint(self, tag=None, module_only=False):
 		if not tag:
 			tag = cfg.trainer.load_tag
 
@@ -256,8 +271,9 @@ class Engines(dict[str, Engine]):
 				tag=tag,
 				load_dir=load_dir,
 				load_module_strict=cfg.trainer.strict_loading,
-				load_optimizer_states=cfg.trainer.load_states,
-				load_lr_scheduler_states=cfg.trainer.load_states,
+				load_optimizer_states=False if module_only else cfg.trainer.load_states,
+				load_lr_scheduler_states=False if module_only else cfg.trainer.load_states,
+				load_module_only=module_only,
 			)
 			if cfg.trainer.restart_step_count:
 				engine.global_steps = 0
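The intent of the new module_only flag is simply to turn the optimizer/scheduler state flags off and ask for a weights-only load. A small sketch of that gating, with plain values standing in for the cfg.trainer fields; the kwarg names mirror the ones passed in the hunk above:

# Sketch: assemble kwargs for a weights-only ("module_only") load versus a full resume.
# strict_loading / load_states stand in for the cfg.trainer fields in the diff.
def load_checkpoint_kwargs(tag, load_dir, strict_loading=True,
                           load_states=True, module_only=False) -> dict:
    return dict(
        tag=tag,
        load_dir=load_dir,
        load_module_strict=strict_loading,
        load_optimizer_states=False if module_only else load_states,
        load_lr_scheduler_states=False if module_only else load_states,
        load_module_only=module_only,
    )

# e.g. engine.load_checkpoint(**load_checkpoint_kwargs("latest", "ckpt", module_only=True))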
@@ -12,16 +12,7 @@ def main():
 	args = parser.parse_args()
 
 	engines = load_engines()
-	for name, engine in engines.items():
-		outpath = cfg.ckpt_dir / name / "fp32.pth"
-		torch.save({
-			"global_step": engine.global_step,
-			"micro_step": engine.micro_step,
-			'module': engine.module.to('cpu', dtype=torch.float32).state_dict(),
-			#'optimizer': engine.optimizer.state_dict(),
-			'symmap': get_phone_symmap(),
-		}, outpath)
-		print(f"Exported {name} to {outpath}")
+	engines.export(userdata={"symmap": get_phone_symmap()})
 
 if __name__ == "__main__":
 	main()
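For completeness, reading one of the exported fp32.pth files back is a plain torch.load: the step counters, the "module" state dict, and any userdata keys (such as "symmap") all sit in the same dict. A hedged sketch, using a toy model and an assumed path purely to keep it self-contained:

# Sketch: consume a file written the way Engines.export() writes fp32.pth.
import torch
import torch.nn as nn

state = torch.load("ckpt/example/fp32.pth", map_location="cpu")  # assumed path
print("global_step:", state["global_step"], "micro_step:", state["micro_step"])
symmap = state.get("symmap")        # userdata merged in at export time

model = nn.Linear(4, 4)             # toy stand-in for the real model
model.load_state_dict(state["module"])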
@@ -33,6 +33,7 @@ from ..models import get_models
 
 from .utils import to_device, do_gc
 from ..utils import wrapper as ml
+from ..data import get_phone_symmap # should decouple from this trainer script
 
 _logger = logging.getLogger(__name__)
 _engines: Engines
@@ -69,6 +70,9 @@ def load_engines():
 		optimizer = None
 		lr_scheduler = None
 
+		# yuck; should instead check optimizer == "adamw" and backend != "deepspeed"
+		# and then have ds_cfg pass in the config flag to use pytorch adamw
+		# I genuinely cannot validate if this ever actually gets used in DeepSpeed
 		if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
 			optimizer = ml.AdamW(
 				model.parameters(),
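The "yuck" comment above proposes a cleaner gate than matching on the literal "adamw-torch" string. A rough sketch of what that check could look like; the optimizer_name and backend arguments are stand-ins for cfg.hyperparameters.optimizer and the configured backend:

# Sketch of the check the comment proposes: only build a PyTorch AdamW when
# the configured optimizer is AdamW and the backend is not DeepSpeed;
# otherwise return None and let the backend's own config supply the optimizer.
import torch

def maybe_build_optimizer(model, optimizer_name: str, backend: str, lr: float = 1e-4):
    if optimizer_name.lower() in ("adamw", "adamw-torch") and backend != "deepspeed":
        return torch.optim.AdamW(model.parameters(), lr=lr)
    return None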
@@ -86,6 +90,9 @@ def load_engines():
 		if "module" in state:
 			state = state["module"]
 
+		# should decouple the following from this trainer script
+		# probably with passing a fun that defaults to a lambda x: x deal
+
 		# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
 		if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
 			o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
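The surrounding proms_emb code grows the prompt embedding when n_prom_levels or n_prom_tokens increases, keeping the already-trained slice. A generic sketch of that kind of resize; shapes and init scale are illustrative:

# Sketch: grow a (levels, tokens, d_model) embedding weight to a larger shape
# while preserving the trained slice, as the proms_emb extension above does.
import torch

def grow_embedding(old: torch.Tensor, n_levels: int, n_tokens: int) -> torch.Tensor:
    o_levels, o_tokens, d_model = old.shape
    new = torch.randn(n_levels, n_tokens, d_model) * 0.02  # fresh init for new slots
    new[:o_levels, :o_tokens, :] = old                     # copy existing weights
    return new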
@@ -301,8 +308,18 @@ def train(
 
 			if "lr" in command:
 				rate = float(command.split(" ")[-1])
-				engines.set_lr(rate)
-				print("Updating LR to:", rate)
+				try:
+					engines.set_lr(rate)
+					print("Updating LR to:", rate)
+				except Exception as e:
+					print("Failed to set LR rate to:", rate, str(e))
+
+			if "export" in command:
+				engines.save_checkpoint()
+				last_save_step = engines.global_step
+
+				if is_global_leader():
+					engines.export(userdata={"symmap": get_phone_symmap()})
 
 			save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
 
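The new "export" branch saves a regular checkpoint and then lets only the global leader write the fp32 export, so multiple ranks don't race on the same files. A rough standalone sketch of that gate, with torch.distributed's rank check standing in for the repo's is_global_leader() helper:

# Sketch of the leader-gated export command: every rank saves the checkpoint,
# only rank 0 (or a single-process run) writes the fp32 export.
import torch.distributed as dist

def is_global_leader() -> bool:
    # stand-in for the repo's helper
    return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0

def handle_export(engines, symmap: dict):
    # `engines` is expected to expose save_checkpoint()/export() as in the diff above
    engines.save_checkpoint()
    if is_global_leader():
        engines.export(userdata={"symmap": symmap})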