This commit is contained in:
mrq 2024-06-06 13:08:02 -05:00
parent b2194b859a
commit fcac9503e2
5 changed files with 13 additions and 8 deletions

View File

@ -143,11 +143,11 @@ def load_engines(training=True):
del state[k]
# resize text embedding
if model.config.text_tokens != state["text_emb.weight"].shape[0]:
if "text_emb.weight" in state and model.config.text_tokens != state["text_emb.weight"].shape[0]:
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
# resize rvq level embedding
if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
if "rvq_level_emb.weight" in state and model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
model.load_state_dict(state, strict=cfg.trainer.strict_loading)

View File

@ -86,7 +86,7 @@ class Engine():
self._frozen_params.clear()
@property
def _training(self):
def training(self):
if not hasattr(self, "hyper_config"):
return True
return self.hyper_config.training
@ -308,7 +308,7 @@ class Engines(dict[str, Engine]):
"userdata": userdata
}
if callback:
state_dict = callback( state_dict, engine.module )
state_dict = callback( state_dict, engine.hyper_config )
torch.save(state_dict, outpath)
print(f"Exported {name} to {outpath}")
@ -321,7 +321,7 @@ class Engines(dict[str, Engine]):
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items():
if not engine._training:
if not engine.training:
continue
save_dir = cfg.ckpt_dir / name
@ -371,7 +371,7 @@ class Engines(dict[str, Engine]):
def set_lr(self, lr):
for engine in self.values():
if not engine._training:
if not engine.training:
continue
engine.set_lr(lr)
@ -406,7 +406,7 @@ class Engines(dict[str, Engine]):
do_gc()
for name, engine in self.items():
if not engine._training:
if not engine.training:
continue
device = engine.device

View File

@ -77,7 +77,7 @@ class Engine(DeepSpeedEngine):
self._frozen_params.clear()
@property
def _training(self):
def training(self):
return self.hyper_config.training
@property

View File

@ -269,6 +269,8 @@ class AR_NAR(Base):
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=torch.Tensor( [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] ).to( device=device, dtype=torch.int32 ),
)
if recurrent_state is not None:

View File

@ -157,6 +157,9 @@ def train(
# Training loop
for batch in _make_infinite_epochs(train_dl):
if not engine.training:
continue
if engines.global_step >= cfg.trainer.iterations:
break