31 lines
626 B
Python
31 lines
626 B
Python
|
import argparse
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from .data import get_symmap
|
||
|
from .train import load_engines
|
||
|
|
||
|
def load_models():
    """Load every trained engine and return its model moved to the CPU.

    Returns:
        dict: Maps each name obtained by iterating ``load_engines()`` to
        ``engines[name].module.cpu()``.
    """
    engines = load_engines()
    # NOTE(review): ``.module`` presumably unwraps the training-engine wrapper
    # to the bare model, and ``.cpu()`` moves its parameters off any
    # accelerator so the exported files do not require a GPU — confirm
    # against ``.train.load_engines``.
    return {name: engines[name].module.cpu() for name in engines}
|
||
|
|
||
|
def main():
    """Export each trained model to ``<path>/<name>.pt``.

    Parses one positional command-line argument, ``path``, which is treated as
    an output directory (created if missing). Every model returned by
    ``load_models()`` is written with ``torch.save`` to ``<path>/<name>.pt``,
    and a confirmation line is printed per model.
    """
    from pathlib import Path

    # The first positional parameter of ArgumentParser is ``prog`` (the program
    # name shown in usage); the sentence is meant to be the description.
    parser = argparse.ArgumentParser(description="Save trained model to path.")
    parser.add_argument("path", help="Output directory for the exported .pt files.")
    args = parser.parse_args()

    out_dir = Path(args.path)
    # torch.save does not create missing parent directories; create the output
    # directory up front so a fresh path works and a bad path fails before the
    # (potentially slow) model loading.
    out_dir.mkdir(parents=True, exist_ok=True)

    models = load_models()
    for name, model in models.items():
        outpath = out_dir / f"{name}.pt"
        # Saves the whole pickled model object (not a state_dict), so loading
        # it back requires the model's class to be importable.
        torch.save(model, outpath)
        print(f"Exported {name} to {outpath}")
|
||
|
|
||
|
# Run the exporter only when executed as a script. Because this file uses
# relative imports, it presumably has to be launched with ``python -m``.
if "__main__" == __name__:
    main()
|