added ability to specify the attention backend from the CLI and web UI (because I'm tired of editing the YAML)
parent 0d706ec6a1
commit b7b99a25f1
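In short: a single attention-backend override now threads from both entry points (the CLI and the web UI) down to the model constructor, taking precedence over the attention value in the YAML, so the backend can be swapped per run without editing the config. A minimal sketch of the resulting call path; the import path assumes the package's usual layout, and the config path and backend name are placeholder values:

# sketch only; "./training/config.yaml" and "sdpa" are placeholders
from vall_e.inference import TTS

tts = TTS( config="./training/config.yaml", attention="sdpa" )
# internally: load_config() records {"attention": "sdpa"} in self.model_kwargs,
# and load_model() forwards it through load_engines() -> get_models()
# -> get_model() -> the model's __init__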
@@ -45,9 +45,10 @@ def main():
 	parser.add_argument("--device", type=str, default=None)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=None)
+	parser.add_argument("--attention", type=str, default=None)
 	args = parser.parse_args()
 
-	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
+	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
 	tts.inference(
 		text=args.text,
 		references=args.references,
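This is the CLI entry point. Note the default of None rather than "auto": a falsy value never survives the if attention: guard added to load_config() below, so omitting the flag leaves the YAML's attention setting in charge. A toy reduction of the pattern (backend names here are placeholders; the real choices come from AVAILABLE_ATTENTIONS):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--attention", type=str, default=None)

# flag omitted: None, so no override is recorded and the YAML value wins
assert parser.parse_args([]).attention is None
# flag given: the override is forwarded all the way to the model
assert parser.parse_args(["--attention", "sdpa"]).attention == "sdpa"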
@@ -28,8 +28,8 @@ except Exception as e:
 from functools import cache
 
 @cache
-def load_engines(training=True):
-	models = get_models(cfg.models, training=training)
+def load_engines(training=True, **model_kwargs):
+	models = get_models(cfg.models, training=training, **model_kwargs)
 	engines = dict()
 
 	for name, model in models.items():
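Since load_engines() is wrapped in functools.cache, the new keyword arguments become part of the memoization key: two calls that differ only in the attention override build separate engine sets, and TTS.load_model() already calls cache_clear() before reloading. This works because the override values are hashable strings; @cache would raise on unhashable kwargs. A reduced demonstration:

from functools import cache

@cache
def load_engines_sketch(training=True, **model_kwargs):
	# stand-in for the real loader; kwargs participate in the cache key
	print("building engines with", model_kwargs)
	return dict(model_kwargs)

load_engines_sketch(training=False, attention="sdpa")      # builds
load_engines_sketch(training=False, attention="sdpa")      # cache hit, no rebuild
load_engines_sketch(training=False, attention="xformers")  # new key, builds again
load_engines_sketch.cache_clear()                          # what TTS.load_model() does first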
@@ -20,15 +20,16 @@ if deepspeed_available:
 	import deepspeed
 
 class TTS():
-	def __init__( self, config=None, device=None, amp=None, dtype=None ):
+	def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
 		self.loading = True
 
-		self.load_config( config=config, device=device, amp=amp, dtype=dtype )
+		# yes I can just grab **kwargs and forward them here
+		self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
 		self.load_model()
 
 		self.loading = False
 
-	def load_config( self, config=None, device=None, amp=None, dtype=None ):
+	def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
 		if config:
 			print("Loading YAML:", config)
 			cfg.load_yaml( config )
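The new inline comment points at the obvious alternative: instead of growing one named parameter per override, __init__ could accept **kwargs and forward them wholesale. Keeping explicit parameters costs a little repetition but keeps the public signature self-documenting. A hypothetical sketch of the kwargs variant, for contrast only:

# hypothetical variant, not what the commit does: every extra keyword
# is forwarded to load_config() untouched
class KwargsVariant:
	def __init__(self, config=None, **kwargs):
		self.load_config(config=config, **kwargs)

	def load_config(self, config=None, device=None, amp=None, dtype=None, attention=None):
		self.attention = attention

assert KwargsVariant(attention="sdpa").attention == "sdpa"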
@@ -57,12 +58,15 @@ class TTS():
 		self.dtype = cfg.inference.dtype
 		self.amp = amp
 
+		self.model_kwargs = {}
+		if attention:
+			self.model_kwargs["attention"] = attention
 
 	def load_model( self ):
 		load_engines.cache_clear()
 		unload_model()
 
-		self.engines = load_engines(training=False)
+		self.engines = load_engines(training=False, **self.model_kwargs)
 		for name, engine in self.engines.items():
			if self.dtype != torch.int8:
				engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
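The guard is a truthiness check, not an is-None check: None and the empty string both skip the override, while any non-empty string (including "auto") is recorded and forwarded. Reduced:

def build_model_kwargs(attention=None):
	# mirrors the guard in load_config()
	model_kwargs = {}
	if attention:
		model_kwargs["attention"] = attention
	return model_kwargs

assert build_model_kwargs(None) == {}                       # CLI default: YAML wins
assert build_model_kwargs("") == {}                         # empty string also falls through
assert build_model_kwargs("sdpa") == {"attention": "sdpa"}  # explicit override
assert build_model_kwargs("auto") == {"attention": "auto"}  # "auto" is truthy, see below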
@@ -1,5 +1,5 @@
 
-def get_model(config, training=True):
+def get_model(config, training=True, **model_kwargs):
 	name = config.name
 
 	if "len" in config.capabilities:
@@ -18,6 +18,7 @@ def get_model(config, training=True):
 
			training = training,
			config = config,
+			**model_kwargs
		)
	elif config.experimental.hf:
		from .experimental import Model as Experimental
@@ -31,6 +32,7 @@ def get_model(config, training=True):
			p_dropout=config.dropout,
 
			config = config,
+			**model_kwargs
		)
	else:
		from .ar_nar import AR_NAR
@@ -48,11 +50,12 @@ def get_model(config, training=True):
 
			training = training,
			config = config,
+			**model_kwargs
		)
 
	print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
 
	return model
 
-def get_models(models, training=True):
-	return { model.full_name: get_model(model, training=training) for model in models }
+def get_models(models, training=True, **model_kwargs):
+	return { model.full_name: get_model(model, training=training, **model_kwargs) for model in models }
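All three model branches (the "len" path, the experimental HF path, and AR_NAR) now take the same **model_kwargs splat, and get_models() forwards it once per model, so a future per-run override needs no further signature changes. The shape of the chain, reduced; the Model stub stands in for the real constructors, and the dict-based get_models() is a simplification of the real list-based one:

# the forwarding chain, reduced to its shape
class Model:
	def __init__(self, training=True, config=None, attention=None):
		self.attention = attention

def get_model(config, training=True, **model_kwargs):
	return Model(training=training, config=config, **model_kwargs)

def get_models(models, training=True, **model_kwargs):
	return { name: get_model(cfg, training=training, **model_kwargs) for name, cfg in models.items() }

models = get_models({ "ar+nar": None }, training=False, attention="sdpa")
assert models["ar+nar"].attention == "sdpa"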
@@ -385,6 +385,7 @@ class Base(nn.Module):
		l_padding: int = 0,
 
		training = True,
+		attention = None,
		config = None,
	):
		super().__init__()
@@ -419,7 +420,10 @@ class Base(nn.Module):
		if self.arch_type in ERROR_ARCHES:
			raise ERROR_ARCHES[self.arch_type]
 
-		attention_backend = self.config.attention if self.config is not None else "auto"
+		if not attention:
+			attention = self.config.attention if self.config is not None else "auto"
+
+		attention_backend = attention
		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
		tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
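The resolution order in the model constructor is now: explicit override first, then the config's attention value, then "auto". Since the check is if not attention:, a falsy override (the CLI's None default) defers to the config, while any non-empty string wins outright. A reduced version, with an assumed backend name as the config value:

from types import SimpleNamespace

def resolve_attention(attention=None, config=None):
	# mirrors the new guard in Base.__init__
	if not attention:
		attention = config.attention if config is not None else "auto"
	return attention

cfg = SimpleNamespace(attention="xformers")        # stand-in for the YAML config
assert resolve_attention(None, cfg) == "xformers"  # no override: config wins
assert resolve_attention("sdpa", cfg) == "sdpa"    # explicit override wins
assert resolve_attention(None, None) == "auto"     # nothing set anywhere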
@@ -74,16 +74,20 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
 def get_dtypes():
 	return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
 
+from .models.arch import AVAILABLE_ATTENTIONS
+def get_attentions():
+	return AVAILABLE_ATTENTIONS + ["auto"]
+
 #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
-def load_model( yaml, device, dtype ):
+def load_model( yaml, device, dtype, attention ):
 	gr.Info(f"Loading: {yaml}")
 	try:
-		init_tts( yaml=Path(yaml), restart=True )
+		init_tts( yaml=Path(yaml), restart=True, device=device, dtype=dtype, attention=attention )
 	except Exception as e:
 		raise gr.Error(e)
 	gr.Info(f"Loaded model")
 
-def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
+def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
 	global tts
 
 	if tts is not None:
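get_attentions() mirrors get_dtypes(): the concrete choices come from AVAILABLE_ATTENTIONS in the arch module, with "auto" appended as the defer-to-defaults option. A reduced illustration; the backend names are assumptions, since the real list depends on what the arch module found importable:

# illustrative only: the real AVAILABLE_ATTENTIONS is built by the arch
# module from whichever backends are actually importable
AVAILABLE_ATTENTIONS = ["sdpa", "flash_attn", "xformers"]

def get_attentions():
	return AVAILABLE_ATTENTIONS + ["auto"]

assert get_attentions() == ["sdpa", "flash_attn", "xformers", "auto"]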
@@ -98,9 +102,10 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
 	parser.add_argument("--device", type=str, default=device)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=dtype)
+	parser.add_argument("--attention", type=str, default=attention)
 	args, unknown = parser.parse_known_args()
 
-	tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
+	tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
 	return tts
 
 @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
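One subtlety worth flagging: unlike the standalone CLI (default None), this parser defaults --attention to the function's attention argument, which the web UI supplies as "auto". Because "auto" is a non-empty string, it passes both the if attention: guard in load_config() and the if not attention: guard in Base.__init__, so, as far as these hunks show, the web UI always forwards a concrete value and the YAML's attention setting is shadowed even with the dropdown left at "auto". A reduced trace:

# reduced trace of the web UI path with the dropdown left at "auto";
# both guards treat any non-empty string as a real override
attention = "auto"                         # init_tts() default / dropdown value
model_kwargs = {}
if attention:                              # truthy, so recorded as an override
	model_kwargs["attention"] = attention
resolved = model_kwargs["attention"]
if not resolved:                           # never fires for "auto"
	resolved = "value-from-yaml"
assert resolved == "auto"                  # the YAML's attention value is shadowed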
@@ -313,8 +318,9 @@ with ui:
		with gr.Column(scale=7):
			with gr.Row():
				layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
-				layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda", label="Device")
+				layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
				layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
+				layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
		with gr.Column(scale=1):
			layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
 
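The device dropdown's default also changes from "cuda" to "cuda:0" here, an unrelated tweak riding along. The wiring of the new Attentions dropdown to load_model() is not part of this diff; presumably the Load button's click handler elsewhere in webui.py now takes it as a fourth input, roughly like this hypothetical sketch:

# not in this diff: assumed wiring of the Load button, with the new
# attentions dropdown appended to the existing inputs
layout["settings"]["buttons"]["load"].click(
	fn=load_model,
	inputs=[
		layout["settings"]["inputs"]["models"],
		layout["settings"]["inputs"]["device"],
		layout["settings"]["inputs"]["dtype"],
		layout["settings"]["inputs"]["attentions"],  # new
	],
)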