fix weird regression in handling checkpoints when backend is local, but deepspeed checkpoints are in (it was handled with LoRA loading but not real loading...)

This commit is contained in:
mrq 2024-07-30 22:15:56 -05:00
parent 07f8e2ad06
commit d7c6be6f78
3 changed files with 7 additions and 8 deletions

View File

@ -48,10 +48,10 @@ def load_engines(training=True):
if cfg.lora is not None:
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
# to handle the issue of training with deepspeed, but inferencing with local
if checkpoint_path.exists() and backend == "local":
tag = open(checkpoint_path).read()
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
# to handle the issue of training with deepspeed, but inferencing with local
if checkpoint_path.exists() and backend == "local":
tag = open(checkpoint_path).read()
checkpoint_path = checkpoint_path.parent / tag / "state.pth"
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found.")

View File

@ -52,6 +52,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
state_dict['module']['lm_head.weight'] = out_proj
del state_dict['module']['classifier.bias']
return state_dict
@ -130,7 +131,7 @@ def main():
# necessary to ensure we are actually exporting the weights right
cfg.inference.backend = cfg.trainer.backend
engines = load_engines(training=False)
engines = load_engines(training=False) # to ignore loading optimizer state
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
if __name__ == "__main__":

View File

@ -190,9 +190,7 @@ class TTS():
phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
lang = to_device(lang, device=self.device, dtype=torch.uint8)
text_list = [ phns ]
proms_list = [ prom ]
# to-do: add in case for experimental.hf model
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
if model_ar is not None:
resps_list = model_ar(