fix weird regression in handling checkpoints when backend is local, but deepspeed checkpoints are in (it was handled with LoRA loading but not real loading...)
This commit is contained in:
parent
07f8e2ad06
commit
d7c6be6f78
|
@ -51,7 +51,7 @@ def load_engines(training=True):
|
|||
# to handle the issue of training with deepspeed, but inferencing with local
|
||||
if checkpoint_path.exists() and backend == "local":
|
||||
tag = open(checkpoint_path).read()
|
||||
checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
|
||||
checkpoint_path = checkpoint_path.parent / tag / "state.pth"
|
||||
|
||||
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
|
||||
print("Checkpoint missing, but weights found.")
|
||||
|
|
|
@ -52,6 +52,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
|||
|
||||
state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
|
||||
state_dict['module']['lm_head.weight'] = out_proj
|
||||
|
||||
del state_dict['module']['classifier.bias']
|
||||
|
||||
return state_dict
|
||||
|
@ -130,7 +131,7 @@ def main():
|
|||
# necessary to ensure we are actually exporting the weights right
|
||||
cfg.inference.backend = cfg.trainer.backend
|
||||
|
||||
engines = load_engines(training=False)
|
||||
engines = load_engines(training=False) # to ignore loading optimizer state
|
||||
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -190,9 +190,7 @@ class TTS():
|
|||
phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
||||
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
||||
|
||||
text_list = [ phns ]
|
||||
proms_list = [ prom ]
|
||||
|
||||
# to-do: add in case for experimental.hf model
|
||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||
if model_ar is not None:
|
||||
resps_list = model_ar(
|
||||
|
|
Loading…
Reference in New Issue
Block a user