added ability to specify the attention backend for the CLI and webui (because I'm tired of editing the YAML)
parent 0d706ec6a1
commit b7b99a25f1
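In short: the attention backend can now be overridden at load time instead of only through the YAML. A minimal sketch of what this enables, assuming the package imports as `vall_e` and that `"sdpa"` is one of the available backends (both are assumptions, not confirmed by this diff):

```python
# hypothetical usage; the import path and backend name are assumptions
from vall_e.inference import TTS

# equivalent to passing --attention sdpa on the CLI
tts = TTS( config="./model.yaml", device="cuda", dtype="bfloat16", attention="sdpa" )
```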
CLI entry point (`main()`):

```diff
@@ -45,9 +45,10 @@ def main():
 	parser.add_argument("--device", type=str, default=None)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=None)
+	parser.add_argument("--attention", type=str, default=None)
 	args = parser.parse_args()
 
-	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
+	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
 	tts.inference(
 		text=args.text,
 		references=args.references,
```
Engine loader (`load_engines`):

```diff
@@ -28,8 +28,8 @@ except Exception as e:
 from functools import cache
 
 @cache
-def load_engines(training=True):
-	models = get_models(cfg.models, training=training)
+def load_engines(training=True, **model_kwargs):
+	models = get_models(cfg.models, training=training, **model_kwargs)
 	engines = dict()
 
 	for name, model in models.items():
```
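Worth noting: `load_engines` is wrapped in `functools.cache`, which keys its memo on keyword arguments too, so anything forwarded through `**model_kwargs` must be hashable (backend-name strings are fine). A minimal, repo-independent sketch of that behavior:

```python
from functools import cache

@cache
def load(training=True, **kwargs):
	# cache keys include keyword arguments, so different kwargs miss the cache
	return object()

a = load(training=False, attention="sdpa")   # executes and caches
b = load(training=False, attention="sdpa")   # cache hit: a is b
c = load(training=False, attention="flash")  # different kwargs: new entry
assert a is b and a is not c
```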
Inference wrapper (`class TTS`):

```diff
@@ -20,15 +20,16 @@ if deepspeed_available:
 	import deepspeed
 
 class TTS():
-	def __init__( self, config=None, device=None, amp=None, dtype=None ):
+	def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
 		self.loading = True
 
-		self.load_config( config=config, device=device, amp=amp, dtype=dtype )
+		# yes I can just grab **kwargs and forward them here
+		self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
 		self.load_model()
 
 		self.loading = False
 
-	def load_config( self, config=None, device=None, amp=None, dtype=None ):
+	def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
 		if config:
 			print("Loading YAML:", config)
 			cfg.load_yaml( config )
@@ -57,12 +58,15 @@ class TTS():
 		self.dtype = cfg.inference.dtype
 		self.amp = amp
 
+		self.model_kwargs = {}
+		if attention:
+			self.model_kwargs["attention"] = attention
 
 	def load_model( self ):
 		load_engines.cache_clear()
 		unload_model()
 
-		self.engines = load_engines(training=False)
+		self.engines = load_engines(training=False, **self.model_kwargs)
 		for name, engine in self.engines.items():
 			if self.dtype != torch.int8:
 				engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
```
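The plumbing is a plain dictionary-spread chain: `load_config` stashes the override in `self.model_kwargs`, and `load_model` splats it into `load_engines`, which forwards it down to the model constructor. A stripped-down sketch of the pattern (names here are illustrative, not the repo's):

```python
def build_model(training=True, **model_kwargs):
	# whatever lands in model_kwargs reaches the constructor untouched
	return dict(training=training, **model_kwargs)

kwargs = {}
attention = "sdpa"  # e.g. from a CLI flag; falsy values are simply omitted
if attention:
	kwargs["attention"] = attention

model = build_model(training=False, **kwargs)
assert model.get("attention") == "sdpa"
```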
Model factory (`get_model` / `get_models`):

```diff
@@ -1,5 +1,5 @@
 
-def get_model(config, training=True):
+def get_model(config, training=True, **model_kwargs):
 	name = config.name
 
 	if "len" in config.capabilities:
@@ -18,6 +18,7 @@ def get_model(config, training=True):
 
 			training = training,
 			config = config,
+			**model_kwargs
 		)
 	elif config.experimental.hf:
 		from .experimental import Model as Experimental
@@ -31,6 +32,7 @@ def get_model(config, training=True):
 			p_dropout=config.dropout,
 
 			config = config,
+			**model_kwargs
 		)
 	else:
 		from .ar_nar import AR_NAR
@@ -48,11 +50,12 @@ def get_model(config, training=True):
 
 			training = training,
 			config = config,
+			**model_kwargs
 		)
 
 	print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
 
 	return model
 
-def get_models(models, training=True):
-	return { model.full_name: get_model(model, training=training) for model in models }
+def get_models(models, training=True, **model_kwargs):
+	return { model.full_name: get_model(model, training=training, **model_kwargs) for model in models }
```
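One sharp edge with this pattern: if `model_kwargs` ever carries a key the call site already passes explicitly (`training`, `config`), Python raises a `TypeError` for the duplicated keyword. A quick illustration:

```python
def ctor(training=True, config=None, **kwargs):
	return kwargs

# fine: "attention" does not collide with the explicit keywords
ctor(training=False, config=object(), **{"attention": "sdpa"})

# would raise: TypeError: ctor() got multiple values for keyword argument 'config'
# ctor(training=False, config=object(), **{"config": None})
```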
`Base` model constructor (attention backend selection):

```diff
@@ -385,6 +385,7 @@ class Base(nn.Module):
 		l_padding: int = 0,
 
 		training = True,
+		attention = None,
 		config = None,
 	):
 		super().__init__()
@@ -419,7 +420,10 @@ class Base(nn.Module):
 		if self.arch_type in ERROR_ARCHES:
 			raise ERROR_ARCHES[self.arch_type]
 
-		attention_backend = self.config.attention if self.config is not None else "auto"
+		if not attention:
+			attention = self.config.attention if self.config is not None else "auto"
+
+		attention_backend = attention
 		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
 		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
 		tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
```
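The resolution order this gives: explicit `attention` argument first, then the YAML's `config.attention`, then `"auto"`. Since the guard is `if not attention:`, any falsy value (`None`, `""`) counts as unset and falls through to the config. Condensed:

```python
def resolve_attention(attention=None, config=None):
	# an explicit argument wins; otherwise fall back to the config, then "auto"
	if not attention:
		attention = config.attention if config is not None else "auto"
	return attention

class Cfg: attention = "xformers"  # illustrative backend name

assert resolve_attention("sdpa", Cfg()) == "sdpa"    # CLI/webui override
assert resolve_attention(None, Cfg()) == "xformers"  # YAML value
assert resolve_attention(None, None) == "auto"       # final fallback
```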
Gradio web UI:

```diff
@@ -74,16 +74,20 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
 def get_dtypes():
 	return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
 
+from .models.arch import AVAILABLE_ATTENTIONS
+def get_attentions():
+	return AVAILABLE_ATTENTIONS + ["auto"]
+
 #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
-def load_model( yaml, device, dtype ):
+def load_model( yaml, device, dtype, attention ):
 	gr.Info(f"Loading: {yaml}")
 	try:
-		init_tts( yaml=Path(yaml), restart=True )
+		init_tts( yaml=Path(yaml), restart=True, device=device, dtype=dtype, attention=attention )
 	except Exception as e:
 		raise gr.Error(e)
 	gr.Info(f"Loaded model")
 
-def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
+def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
 	global tts
 
 	if tts is not None:
@@ -98,9 +102,10 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto"):
 	parser.add_argument("--device", type=str, default=device)
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default=dtype)
+	parser.add_argument("--attention", type=str, default=attention)
 	args, unknown = parser.parse_known_args()
 
-	tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
+	tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
 	return tts
 
 @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
@@ -313,8 +318,9 @@ with ui:
 			with gr.Column(scale=7):
 				with gr.Row():
 					layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
-					layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda", label="Device")
+					layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
 					layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
+					layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")
 			with gr.Column(scale=1):
 				layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
```
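The diff doesn't show the click wiring that feeds the new dropdown into `load_model`; in Gradio it would look roughly like this (the exact keys and input order are assumptions):

```python
# hypothetical wiring sketch; actual keys/order come from layout["settings"]
layout["settings"]["buttons"]["load"].click(
	fn=load_model,
	inputs=[
		layout["settings"]["inputs"]["models"],
		layout["settings"]["inputs"]["device"],
		layout["settings"]["inputs"]["dtype"],
		layout["settings"]["inputs"]["attentions"],
	],
)
```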