diff --git a/tortoise_tts/export.py b/tortoise_tts/export.py index 8af9c7a..f6d3ce2 100755 --- a/tortoise_tts/export.py +++ b/tortoise_tts/export.py @@ -49,7 +49,10 @@ def main(): if args.dtype != "auto": cfg.trainer.weight_dtype = args.dtype - + + # necessary to ensure we are actually exporting the weights right + cfg.inference.backend if inferencing else cfg.trainer.backend + engines = load_engines(training=False) engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)