From b7b99a25f11ad31798d0b81d810dd4d0cb0da206 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 26 Aug 2024 19:33:51 -0500
Subject: [PATCH] added ability to specify attention backend for CLI and webui
 (because I'm tired of editing the YAML)

---
 vall_e/__main__.py         |  3 ++-
 vall_e/engines/__init__.py |  4 ++--
 vall_e/inference.py        | 12 ++++++++----
 vall_e/models/__init__.py  |  9 ++++++---
 vall_e/models/base.py      |  6 +++++-
 vall_e/webui.py            | 16 +++++++++++-----
 6 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index c9a3e43..4c3ea0a 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -45,9 +45,10 @@ def main():
 	parser.add_argument("--device", type=str, default=None)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=None)
+	parser.add_argument("--attention", type=str, default=None)
 	args = parser.parse_args()
 
-	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
+	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
 	tts.inference(
 		text=args.text,
 		references=args.references,
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index a1fd7c9..e012a87 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -28,8 +28,8 @@ except Exception as e:
 from functools import cache
 
 @cache
-def load_engines(training=True):
-	models = get_models(cfg.models, training=training)
+def load_engines(training=True, **model_kwargs):
+	models = get_models(cfg.models, training=training, **model_kwargs)
 	engines = dict()
 
 	for name, model in models.items():
diff --git a/vall_e/inference.py b/vall_e/inference.py
index fa169b5..79a91e0 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -20,15 +20,16 @@ if deepspeed_available:
 	import deepspeed
 
 class TTS():
-	def __init__( self, config=None, device=None, amp=None, dtype=None ):
+	def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
 		self.loading = True
 
-		self.load_config( config=config, device=device, amp=amp, dtype=dtype )
+		# yes I can just grab **kwargs and forward them here
+		self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
 		self.load_model()
 
 		self.loading = False
 
-	def load_config( self, config=None, device=None, amp=None, dtype=None ):
+	def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
 		if config:
 			print("Loading YAML:", config)
 			cfg.load_yaml( config )
@@ -57,12 +58,15 @@ class TTS():
 			self.dtype = cfg.inference.dtype
 
 		self.amp = amp
+
+		self.model_kwargs = {}
+		if attention:
+			self.model_kwargs["attention"] = attention
 
 	def load_model( self ):
 		load_engines.cache_clear()
 		unload_model()
 
-		self.engines = load_engines(training=False)
+		self.engines = load_engines(training=False, **self.model_kwargs)
 		for name, engine in self.engines.items():
 			if self.dtype != torch.int8:
 				engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index b6cfa51..4e37356 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -1,5 +1,5 @@
-def get_model(config, training=True):
+def get_model(config, training=True, **model_kwargs):
 	name = config.name
 
 	if "len" in config.capabilities:
@@ -18,6 +18,7 @@ def get_model(config, training=True):
 			training = training,
 			config = config,
+			**model_kwargs
 		)
 	elif config.experimental.hf:
 		from .experimental import Model as Experimental
@@ -31,6 +32,7 @@ def get_model(config, training=True):
 			p_dropout=config.dropout,
 			config = config,
+			**model_kwargs
 		)
 	else:
 		from .ar_nar import AR_NAR
@@ -48,11 +50,12 @@ def get_model(config, training=True):
 			training = training,
 			config = config,
+			**model_kwargs
 		)
 
 	print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
 
 	return model
 
-def get_models(models, training=True):
-	return { model.full_name: get_model(model, training=training) for model in models }
+def get_models(models, training=True, **model_kwargs):
+	return { model.full_name: get_model(model, training=training, **model_kwargs) for model in models }
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e4fafde..1ee8026 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -385,6 +385,7 @@ class Base(nn.Module):
 		l_padding: int = 0,
 
 		training = True,
+		attention = None,
 		config = None,
 	):
 		super().__init__()
@@ -419,7 +420,10 @@ class Base(nn.Module):
 		if self.arch_type in ERROR_ARCHES:
 			raise ERROR_ARCHES[self.arch_type]
 
-		attention_backend = self.config.attention if self.config is not None else "auto"
+		if not attention:
+			attention = self.config.attention if self.config is not None else "auto"
+
+		attention_backend = attention
 		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
 		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
 		tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
diff --git a/vall_e/webui.py b/vall_e/webui.py
index cf5eecc..0ca05e0 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -74,16 +74,20 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
 def get_dtypes():
 	return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
 
+from .models.arch import AVAILABLE_ATTENTIONS
+def get_attentions():
+	return AVAILABLE_ATTENTIONS + ["auto"]
+
 #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
-def load_model( yaml, device, dtype ):
+def load_model( yaml, device, dtype, attention ):
 	gr.Info(f"Loading: {yaml}")
 	try:
-		init_tts( yaml=Path(yaml), restart=True )
+		init_tts( yaml=Path(yaml), restart=True, device=device, dtype=dtype, attention=attention )
 	except Exception as e:
 		raise gr.Error(e)
 	gr.Info(f"Loaded model")
 
-def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
+def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
 	global tts
 
 	if tts is not None:
@@ -98,9 +102,10 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
 	parser.add_argument("--device", type=str, default=device)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=dtype)
+	parser.add_argument("--attention", type=str, default=attention)
 	args, unknown = parser.parse_known_args()
 
-	tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
+	tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
 	return tts
 
 @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
@@ -313,8 +318,9 @@ with ui:
 		with gr.Column(scale=7):
 			with gr.Row():
 				layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
-				layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda", label="Device")
+				layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
 				layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
+				layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
 
 		with gr.Column(scale=1):
 			layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
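
Usage note: with this patch the attention backend can be overridden in three places
without touching the YAML: `--attention` on the CLI, the new "Attentions" dropdown in
the web UI, and the `attention=` keyword when driving `TTS` from Python. A minimal
sketch of the Python path, using the `TTS` signature added above; the config path and
the "sdpa" backend name are placeholders, substitute your own YAML and any entry from
AVAILABLE_ATTENTIONS:

    from vall_e.inference import TTS

    # hypothetical config path and backend name; `attention` overrides the
    # backend named in the YAML, while leaving it as None keeps the old
    # behavior (the config's attention value, else "auto")
    tts = TTS( config="./training/config.yaml", attention="sdpa" )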