vall-e/vall_e/export.py

27 lines
753 B
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
import argparse
import torch
from .data import get_phone_symmap
from .train import load_engines
2023-08-14 03:56:28 +00:00
from .config import cfg
2023-08-02 21:53:35 +00:00
def main():
    """Export every loaded engine's weights to a standalone fp32 checkpoint.

    For each engine returned by ``load_engines()``, writes
    ``cfg.ckpt_dir / <engine name> / "fp32.pth"`` containing:

    * ``global_step`` / ``micro_step`` -- the engine's step counters,
    * ``module`` -- the model ``state_dict`` after moving it to CPU/float32,
    * ``symmap`` -- the phoneme symbol map from ``get_phone_symmap()``.

    Optimizer state is not saved.
    """
    # The first positional argument of ArgumentParser is `prog`, so the
    # sentence must be passed as `description` to show up as help text.
    parser = argparse.ArgumentParser(description="Save trained model to path.")
    # No options of our own yet; parsing still provides `--help` and rejects
    # unknown flags.
    # NOTE(review): if the config is selected via a CLI flag parsed elsewhere
    # (e.g. `--yaml`), strict `parse_args()` will reject it -- confirm.
    parser.parse_args()

    engines = load_engines()
    # The symbol map does not depend on the engine; build it once.
    symmap = get_phone_symmap()

    for name, engine in engines.items():
        outpath = cfg.ckpt_dir / name / "fp32.pth"
        # Ensure the per-engine directory exists before writing into it.
        outpath.parent.mkdir(parents=True, exist_ok=True)

        # Assumes `engine.module` is a torch.nn.Module, whose `.to()` converts
        # in place -- the engine is left on CPU/fp32 after export.
        state_dict = engine.module.to("cpu", dtype=torch.float32).state_dict()

        torch.save({
            "global_step": engine.global_step,
            "micro_step": engine.micro_step,
            "module": state_dict,
            "symmap": symmap,
        }, outpath)

        print(f"Exported {name} to {outpath}")
# Run the export only when executed as a script / module, not on import.
if __name__ == "__main__":
    main()