2023-08-02 21:53:35 +00:00
|
|
|
import argparse
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from .data import get_phone_symmap
|
|
|
|
from .train import load_engines
|
2023-08-14 03:56:28 +00:00
|
|
|
from .config import cfg
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
def main():
    """Export each trained engine as a standalone fp32 checkpoint.

    Engines are obtained from ``load_engines()``. For every ``(name, engine)``
    pair, the engine's module is moved to the CPU and cast to float32. A
    dictionary is then written with ``torch.save`` to
    ``cfg.ckpt_dir / name / "fp32.pth"``. The dictionary has these keys:

    - ``"global_step"``: ``engine.global_step``
    - ``"micro_step"``: ``engine.micro_step``
    - ``"module"``: the module's ``state_dict()``
    - ``"symmap"``: the result of ``get_phone_symmap()``

    Optimizer state is not included in the export.

    Side effect: ``nn.Module.to`` converts a module in place, so each engine's
    module is left on the CPU in float32 after this function runs.
    """
    # The text must be passed as ``description``. The first positional
    # parameter of ArgumentParser is ``prog``, which would make it show up as
    # the program name in the usage line.
    parser = argparse.ArgumentParser(description="Save trained model to path.")
    # This parser defines no options of its own; it is kept for ``--help``.
    # Configuration is read through ``cfg``, presumably from flags such as
    # ``--yaml`` on the same command line (NOTE(review): confirm against
    # .config). parse_known_args() tolerates those flags, whereas
    # parse_args() would exit with "unrecognized arguments".
    parser.parse_known_args()

    engines = load_engines()

    for name, engine in engines.items():
        outpath = cfg.ckpt_dir / name / "fp32.pth"
        # torch.save does not create missing parent directories.
        outpath.parent.mkdir(parents=True, exist_ok=True)

        # nn.Module.to() converts parameters/buffers in place and returns the
        # same module object.
        module = engine.module.to("cpu", dtype=torch.float32)

        torch.save({
            "global_step": engine.global_step,
            "micro_step": engine.micro_step,
            "module": module.state_dict(),
            "symmap": get_phone_symmap(),
        }, outpath)

        print(f"Exported {name} to {outpath}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The module uses relative imports (``from .data import ...``), so it has
    # to be executed as part of its package, i.e. via ``python -m``.
    main()
|