cleanup
This commit is contained in:
parent
b2194b859a
commit
fcac9503e2
|
@ -143,11 +143,11 @@ def load_engines(training=True):
|
|||
del state[k]
|
||||
|
||||
# resize text embedding
|
||||
if model.config.text_tokens != state["text_emb.weight"].shape[0]:
|
||||
if "text_emb.weight" in state and model.config.text_tokens != state["text_emb.weight"].shape[0]:
|
||||
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
|
||||
|
||||
# resize rvq level embedding
|
||||
if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
|
||||
if "rvq_level_emb.weight" in state and model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
|
||||
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
|
|
@ -86,7 +86,7 @@ class Engine():
|
|||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def _training(self):
|
||||
def training(self):
|
||||
if not hasattr(self, "hyper_config"):
|
||||
return True
|
||||
return self.hyper_config.training
|
||||
|
@ -308,7 +308,7 @@ class Engines(dict[str, Engine]):
|
|||
"userdata": userdata
|
||||
}
|
||||
if callback:
|
||||
state_dict = callback( state_dict, engine.module )
|
||||
state_dict = callback( state_dict, engine.hyper_config )
|
||||
torch.save(state_dict, outpath)
|
||||
print(f"Exported {name} to {outpath}")
|
||||
|
||||
|
@ -321,7 +321,7 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
for name, engine in self.items():
|
||||
if not engine._training:
|
||||
if not engine.training:
|
||||
continue
|
||||
|
||||
save_dir = cfg.ckpt_dir / name
|
||||
|
@ -371,7 +371,7 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
def set_lr(self, lr):
|
||||
for engine in self.values():
|
||||
if not engine._training:
|
||||
if not engine.training:
|
||||
continue
|
||||
engine.set_lr(lr)
|
||||
|
||||
|
@ -406,7 +406,7 @@ class Engines(dict[str, Engine]):
|
|||
do_gc()
|
||||
|
||||
for name, engine in self.items():
|
||||
if not engine._training:
|
||||
if not engine.training:
|
||||
continue
|
||||
|
||||
device = engine.device
|
||||
|
|
|
@ -77,7 +77,7 @@ class Engine(DeepSpeedEngine):
|
|||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def _training(self):
|
||||
def training(self):
|
||||
return self.hyper_config.training
|
||||
|
||||
@property
|
||||
|
|
|
@ -269,6 +269,8 @@ class AR_NAR(Base):
|
|||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
|
||||
quant_levels=torch.Tensor( [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] ).to( device=device, dtype=torch.int32 ),
|
||||
)
|
||||
|
||||
if recurrent_state is not None:
|
||||
|
|
|
@ -157,6 +157,9 @@ def train(
|
|||
|
||||
# Training loop
|
||||
for batch in _make_infinite_epochs(train_dl):
|
||||
if not engine.training:
|
||||
continue
|
||||
|
||||
if engines.global_step >= cfg.trainer.iterations:
|
||||
break
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user