From c8f31db1dea8ac134734da5829a0e8e4710533fa Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 18 Oct 2024 16:58:56 -0500 Subject: [PATCH] default to greedy sample AR (i should probably test this more but it seems to pass my harvard sentences and tongue twisters) --- vall_e/__main__.py | 2 +- vall_e/config.py | 2 +- vall_e/demo.py | 4 ++-- vall_e/inference.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vall_e/__main__.py b/vall_e/__main__.py index b7126ab..b979b04 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -21,7 +21,7 @@ def main(): parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second) parser.add_argument("--max-nar-levels", type=int, default=7) - parser.add_argument("--ar-temp", type=float, default=1.0) + parser.add_argument("--ar-temp", type=float, default=0.0) parser.add_argument("--nar-temp", type=float, default=0.01) parser.add_argument("--min-ar-temp", type=float, default=-1.0) parser.add_argument("--min-nar-temp", type=float, default=-1.0) diff --git a/vall_e/config.py b/vall_e/config.py index eab5d7f..91c77be 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -429,7 +429,7 @@ class Evaluation: size: int = 64 # number of samples to generate during eval / val steps: int = 500 - ar_temperature: float = 1.0 # AR temp for inferencing + ar_temperature: float = 0.0 # AR temp for inferencing nar_temperature: float = 0.0 # NAR temp for inferencing nar_levels: int = 0 # maximum NAR levels to use for inferencing diff --git a/vall_e/demo.py b/vall_e/demo.py index 9c65ebf..1329af1 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -58,7 +58,7 @@ def main(): parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second) parser.add_argument("--max-nar-levels", type=int, default=7) - parser.add_argument("--ar-temp", type=float, default=1.0) + parser.add_argument("--ar-temp", type=float, default=0.0) parser.add_argument("--nar-temp", type=float, default=0.0) parser.add_argument("--min-ar-temp", type=float, default=-1.0) parser.add_argument("--min-nar-temp", type=float, default=-1.0) @@ -155,7 +155,7 @@ def main(): elif args.comparison == "dtype": current_dtype = cfg.inference.weight_dtype other_dtype = "float32" - + if current_dtype == "float16": other_dtype = "bfloat16" elif current_dtype == "bfloat16": diff --git a/vall_e/inference.py b/vall_e/inference.py index a530813..aef6b02 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -185,11 +185,11 @@ class TTS(): input_prompt_length=0.0, input_prompt_prefix=False, # - ar_temp=0.95, - nar_temp=0.5, + ar_temp=0.0, + nar_temp=0.0, # - min_ar_temp=0.95, - min_nar_temp=0.5, + min_ar_temp=0.0, + min_nar_temp=0.0, # top_p=1.0, top_k=0,