diff --git a/src/cli.py b/src/cli.py index d463edc..b3d7ddb 100755 --- a/src/cli.py +++ b/src/cli.py @@ -32,6 +32,7 @@ if __name__ == "__main__": parser.add_argument("--breathing_room", default=default_arguments['breathing_room']) parser.add_argument("--cvvp_weight", default=default_arguments['cvvp_weight']) parser.add_argument("--top_p", default=default_arguments['top_p']) + parser.add_argument("--autoregressive_model", default=default_arguments['autoregressive_model']) parser.add_argument("--diffusion_temperature", default=default_arguments['diffusion_temperature']) parser.add_argument("--length_penalty", default=default_arguments['length_penalty']) parser.add_argument("--repetition_penalty", default=default_arguments['repetition_penalty']) @@ -55,6 +56,7 @@ if __name__ == "__main__": 'breathing_room': args.breathing_room, 'cvvp_weight': args.cvvp_weight, 'top_p': args.top_p, + 'autoregressive_model': args.autoregressive_model, 'diffusion_temperature': args.diffusion_temperature, 'length_penalty': args.length_penalty, 'repetition_penalty': args.repetition_penalty, @@ -62,5 +64,6 @@ if __name__ == "__main__": 'experimentals': default_arguments['experimentals'], } - tts = load_tts() + # cli should rely on generate() for loading the TTS backend + #tts = load_tts() generate(**kwargs) \ No newline at end of file