added ability to specify attention backend for CLI and webui (because im tired of editing the yaml)

This commit is contained in:
mrq 2024-08-26 19:33:51 -05:00
parent 0d706ec6a1
commit b7b99a25f1
6 changed files with 34 additions and 16 deletions

View File

@ -45,9 +45,10 @@ def main():
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=None)
parser.add_argument("--attention", type=str, default=None)
args = parser.parse_args()
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
tts.inference(
text=args.text,
references=args.references,

View File

@ -28,8 +28,8 @@ except Exception as e:
from functools import cache
@cache
def load_engines(training=True):
models = get_models(cfg.models, training=training)
def load_engines(training=True, **model_kwargs):
models = get_models(cfg.models, training=training, **model_kwargs)
engines = dict()
for name, model in models.items():

View File

@ -20,15 +20,16 @@ if deepspeed_available:
import deepspeed
class TTS():
def __init__( self, config=None, device=None, amp=None, dtype=None ):
def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
self.loading = True
self.load_config( config=config, device=device, amp=amp, dtype=dtype )
# yes I can just grab **kwargs and forward them here
self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
self.load_model()
self.loading = False
def load_config( self, config=None, device=None, amp=None, dtype=None ):
def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
if config:
print("Loading YAML:", config)
cfg.load_yaml( config )
@ -57,12 +58,15 @@ class TTS():
self.dtype = cfg.inference.dtype
self.amp = amp
self.model_kwargs = {}
if attention:
self.model_kwargs["attention"] = attention
def load_model( self ):
load_engines.cache_clear()
unload_model()
self.engines = load_engines(training=False)
self.engines = load_engines(training=False, **self.model_kwargs)
for name, engine in self.engines.items():
if self.dtype != torch.int8:
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

View File

@ -1,5 +1,5 @@
def get_model(config, training=True):
def get_model(config, training=True, **model_kwargs):
name = config.name
if "len" in config.capabilities:
@ -18,6 +18,7 @@ def get_model(config, training=True):
training = training,
config = config,
**model_kwargs
)
elif config.experimental.hf:
from .experimental import Model as Experimental
@ -31,6 +32,7 @@ def get_model(config, training=True):
p_dropout=config.dropout,
config = config,
**model_kwargs
)
else:
from .ar_nar import AR_NAR
@ -48,11 +50,12 @@ def get_model(config, training=True):
training = training,
config = config,
**model_kwargs
)
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
return model
def get_models(models, training=True):
return { model.full_name: get_model(model, training=training) for model in models }
def get_models(models, training=True, **model_kwargs):
return { model.full_name: get_model(model, training=training, **model_kwargs) for model in models }

View File

@ -385,6 +385,7 @@ class Base(nn.Module):
l_padding: int = 0,
training = True,
attention = None,
config = None,
):
super().__init__()
@ -419,7 +420,10 @@ class Base(nn.Module):
if self.arch_type in ERROR_ARCHES:
raise ERROR_ARCHES[self.arch_type]
attention_backend = self.config.attention if self.config is not None else "auto"
if not attention:
attention = self.config.attention if self.config is not None else "auto"
attention_backend = attention
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False

View File

@ -74,16 +74,20 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
def get_dtypes():
return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
from .models.arch import AVAILABLE_ATTENTIONS
def get_attentions():
return AVAILABLE_ATTENTIONS + ["auto"]
#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model( yaml, device, dtype ):
def load_model( yaml, device, dtype, attention ):
gr.Info(f"Loading: {yaml}")
try:
init_tts( yaml=Path(yaml), restart=True )
init_tts( yaml=Path(yaml), restart=True, device=device, dtype=dtype, attention=attention )
except Exception as e:
raise gr.Error(e)
gr.Info(f"Loaded model")
def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
global tts
if tts is not None:
@ -98,9 +102,10 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
parser.add_argument("--device", type=str, default=device)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=dtype)
parser.add_argument("--attention", type=str, default=attention)
args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
return tts
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
@ -313,8 +318,9 @@ with ui:
with gr.Column(scale=7):
with gr.Row():
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda", label="Device")
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
with gr.Column(scale=1):
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")