diff --git a/vall_e/export.py b/vall_e/export.py index 5649f50..f0cb66c 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -127,6 +127,9 @@ def main(): if args.dtype != "auto": cfg.trainer.weight_dtype = args.dtype + # necessary to ensure we are actually exporting the weights right + cfg.inference.backend = cfg.trainer.backend + engines = load_engines(training=False) engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)