cleanup
This commit is contained in:
parent
b2194b859a
commit
fcac9503e2
vall_e
|
@ -143,11 +143,11 @@ def load_engines(training=True):
|
||||||
del state[k]
|
del state[k]
|
||||||
|
|
||||||
# resize text embedding
|
# resize text embedding
|
||||||
if model.config.text_tokens != state["text_emb.weight"].shape[0]:
|
if "text_emb.weight" in state and model.config.text_tokens != state["text_emb.weight"].shape[0]:
|
||||||
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
|
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
|
||||||
|
|
||||||
# resize text embedding
|
# resize text embedding
|
||||||
if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
|
if "rvq_level_emb.weight" in state and model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
|
||||||
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
|
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
|
||||||
|
|
||||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||||
|
|
|
@ -86,7 +86,7 @@ class Engine():
|
||||||
self._frozen_params.clear()
|
self._frozen_params.clear()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _training(self):
|
def training(self):
|
||||||
if not hasattr(self, "hyper_config"):
|
if not hasattr(self, "hyper_config"):
|
||||||
return True
|
return True
|
||||||
return self.hyper_config.training
|
return self.hyper_config.training
|
||||||
|
@ -308,7 +308,7 @@ class Engines(dict[str, Engine]):
|
||||||
"userdata": userdata
|
"userdata": userdata
|
||||||
}
|
}
|
||||||
if callback:
|
if callback:
|
||||||
state_dict = callback( state_dict, engine.module )
|
state_dict = callback( state_dict, engine.hyper_config )
|
||||||
torch.save(state_dict, outpath)
|
torch.save(state_dict, outpath)
|
||||||
print(f"Exported {name} to {outpath}")
|
print(f"Exported {name} to {outpath}")
|
||||||
|
|
||||||
|
@ -321,7 +321,7 @@ class Engines(dict[str, Engine]):
|
||||||
|
|
||||||
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||||
for name, engine in self.items():
|
for name, engine in self.items():
|
||||||
if not engine._training:
|
if not engine.training:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
save_dir = cfg.ckpt_dir / name
|
save_dir = cfg.ckpt_dir / name
|
||||||
|
@ -371,7 +371,7 @@ class Engines(dict[str, Engine]):
|
||||||
|
|
||||||
def set_lr(self, lr):
|
def set_lr(self, lr):
|
||||||
for engine in self.values():
|
for engine in self.values():
|
||||||
if not engine._training:
|
if not engine.training:
|
||||||
continue
|
continue
|
||||||
engine.set_lr(lr)
|
engine.set_lr(lr)
|
||||||
|
|
||||||
|
@ -406,7 +406,7 @@ class Engines(dict[str, Engine]):
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
for name, engine in self.items():
|
for name, engine in self.items():
|
||||||
if not engine._training:
|
if not engine.training:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
device = engine.device
|
device = engine.device
|
||||||
|
|
|
@ -77,7 +77,7 @@ class Engine(DeepSpeedEngine):
|
||||||
self._frozen_params.clear()
|
self._frozen_params.clear()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _training(self):
|
def training(self):
|
||||||
return self.hyper_config.training
|
return self.hyper_config.training
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -269,6 +269,8 @@ class AR_NAR(Base):
|
||||||
resps_list=resps_list,
|
resps_list=resps_list,
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
tone_list=tone_list,
|
tone_list=tone_list,
|
||||||
|
|
||||||
|
quant_levels=torch.Tensor( [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] ).to( device=device, dtype=torch.int32 ),
|
||||||
)
|
)
|
||||||
|
|
||||||
if recurrent_state is not None:
|
if recurrent_state is not None:
|
||||||
|
|
|
@ -157,6 +157,9 @@ def train(
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
for batch in _make_infinite_epochs(train_dl):
|
for batch in _make_infinite_epochs(train_dl):
|
||||||
|
if not engine.training:
|
||||||
|
continue
|
||||||
|
|
||||||
if engines.global_step >= cfg.trainer.iterations:
|
if engines.global_step >= cfg.trainer.iterations:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user