# vall-e/vall_e/export.py
#
# 27 lines
# 753 B
# Python
# Executable File

import argparse
import torch
from .data import get_phone_symmap
from .train import load_engines
from .config import cfg
def main():
    """Export every loaded engine's weights to a standalone fp32 checkpoint.

    For each ``(name, engine)`` returned by ``load_engines()``, writes
    ``cfg.ckpt_dir / name / "fp32.pth"`` via ``torch.save`` containing:

    * ``global_step`` / ``micro_step`` -- the engine's step counters,
    * ``module`` -- the model ``state_dict`` after casting to float32 on CPU,
    * ``symmap`` -- the phone symbol map from ``get_phone_symmap()``.

    Optimizer state is intentionally not saved (weights-only export).

    Side effect: ``nn.Module.to`` converts in place, so each engine's module
    is left on the CPU in float32 after this runs.
    """
    # ArgumentParser's first positional parameter is ``prog``; pass the text
    # as ``description`` so it appears in --help rather than replacing the
    # program name in the usage line.
    parser = argparse.ArgumentParser(description="Save trained model to path.")
    # This script defines no options of its own. NOTE(review): config flags
    # such as --yaml appear to be consumed by ``cfg`` elsewhere -- tolerate
    # unknown arguments instead of exiting with "unrecognized arguments".
    parser.parse_known_args()

    engines = load_engines()
    for name, engine in engines.items():
        outpath = cfg.ckpt_dir / name / "fp32.pth"
        torch.save(
            {
                "global_step": engine.global_step,
                "micro_step": engine.micro_step,
                "module": engine.module.to("cpu", dtype=torch.float32).state_dict(),
                "symmap": get_phone_symmap(),
            },
            outpath,
        )
        print(f"Exported {name} to {outpath}")
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()