allow loading a different model within the web ui (apparently I did not have the web UI in the documentation)

mrq 2024-07-15 19:59:48 -05:00
parent 7b210d9738
commit 3acc54df22
9 changed files with 129 additions and 67 deletions

View File

@@ -20,7 +20,6 @@ Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install g
 I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
 ## Pre-Trained Model
 > [!NOTE]
@@ -196,11 +195,34 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 	+ **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime at least).
 * `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise.
+
+### Web UI
+
+A Gradio-based web UI is accessible by running `python3 -m vall_e.webui`. You can, optionally, pass:
+
+* `--yaml=./path/to/your/config.yaml`: will load the targeted YAML
+* `--listen 0.0.0.0:7860`: will set the web UI to listen to all IPs at port 7860. Adjust the IP and port to your preference.
+
+#### Inference
+
+Synthesizing speech is simple:
+
+* `Input Prompt`: The guiding text prompt. Each new line will be its own generated audio to be stitched together at the end.
+* `Audio Input`: The reference audio for the synthesis. Under Gradio, you can trim your clip accordingly, but leaving it as-is works fine.
+* `Output`: The resultant audio.
+* `Inference`: Button to start generating the audio.
+
+All the additional knobs have a description that can be correlated to the above CLI flags.
+
+#### Settings
+
+So far, this only allows you to load a different model without needing to restart. The previous model should seamlessly unload, and the new one will load in its place.
+
 ## To-Do
 * [x] train and release a serviceable model for finetuning against.
 * [ ] train and release a ***good*** zero-shot model.
   - this should, hopefully, just simply requires another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
+* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
 * [x] ~~explore alternative setups, like a NAR-only model~~
   - the current experiment of an AR length-predictor + NAR for the rest seems to fall apart...
 * [x] ~~explore better sampling techniques~~
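As a usage note for the new Web UI section above: a typical launch combining both documented flags would be `python3 -m vall_e.webui --yaml=./training/config.yaml --listen 0.0.0.0:7860`, where the YAML path is illustrative and should point at your own config.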
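The reload behavior described under the new `#### Settings` subsection is implemented by the `webui.py` and `inference.py` changes later in this commit. A minimal sketch of that flow, using names taken from those diffs but simplified (and an illustrative config path), might look like:

```python
from pathlib import Path

from vall_e.inference import TTS
from vall_e.engines import load_engines
from vall_e.emb.qnt import unload_model

tts = None

def reload_model( yaml ):
	global tts
	# drop the previous TTS instance, as webui.py's init_tts() does
	if tts is not None:
		del tts
	# TTS.load_model() also clears the cached engines and unloads the quantizer;
	# shown explicitly here for clarity
	load_engines.cache_clear()
	unload_model()
	# constructing a TTS loads the YAML config, then the weights, via load_config()/load_model()
	tts = TTS( config=Path(yaml) )
	return tts

reload_model( "./training/config.yaml" )  # path is illustrative
```

In the actual web UI, this is wired to a `Load Model` button and a dropdown populated by `get_model_paths()`, which scans `./training/` and `./models/` for YAML configs.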

View File

@@ -210,11 +210,10 @@ class ModelExperimentalSettings:
 # I really need to clean this up
 @dataclass()
 class Model:
-	name: str = "" # vanity name for the model
+	name: str = "ar+nar" # vanity name for the model
 	version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
 	size: str | dict = "full" # preset string or explicitly defined dimensionality
-	resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
-	prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
+	resp_levels: int = 8 # RVQ-bin levels this model supports
 	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
 	langs: int = 1 # defined languages (semi-unused)
 	tones: int = 1 # defined tones (unsued)
@@ -238,7 +237,10 @@
 	@property
 	def max_levels(self):
-		return max(self.prom_levels, self.resp_levels)
+		# return RVQ level range
+		if self.experimental is not None and self.experimental.rvq_level_range:
+			return self.experimental.rvq_level_range[-1]
+		return self.resp_levels
 	@property
 	# required for fp8 as the lengths needs to be divisible by 8
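(Note: with this change, a model whose `experimental.rvq_level_range` ends at, say, `4` reports `max_levels == 4`; a model without that field simply falls back to `resp_levels`, which now defaults to `8`.)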
@@ -626,7 +628,7 @@ class Inference:
 	use_encodec: bool = True
 	use_dac: bool = True
-	@cached_property
+	@property
 	def dtype(self):
 		if self.weight_dtype == "float16":
 			return torch.float16
@@ -651,7 +653,7 @@ class Optimizations:
 	optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
 	bitsandbytes: bool = False # use bitsandbytes
-	dadaptation: bool = True # use dadaptation optimizer
+	dadaptation: bool = False # use dadaptation optimizer
 	bitnet: bool = False # use bitnet
 	fp8: bool = False # use fp8
@@ -671,7 +673,8 @@ class Config(BaseConfig):
 	bitsandbytes: dict | list | None = None # deprecated
 	optimizations: Optimizations = field(default_factory=lambda: Optimizations)
-	tokenizer: str = "./tokenizer.json"
+	tokenizer: str | None = None
+	tokenizer_path: str = "./tokenizer.json"
 	sample_rate: int = 24_000
 	variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
@@ -760,14 +763,21 @@
 		self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
 		self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
+		# do cleanup
 		for model in self.models:
 			if not isinstance( model, dict ):
 				continue
+			if "prom_levels" in model:
+				del model["prom_levels"]
+			if "interleave" in model:
+				del model["interleave"]
 			if "audio_embedding_sums" not in model:
 				continue
-			if not model["experimental"]:
+			if "experimental" not in model or not model["experimental"]:
 				model["experimental"] = {}
 			model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
@@ -837,9 +847,9 @@ class Config(BaseConfig):
 try:
 	from transformers import PreTrainedTokenizerFast
-	tokenizer_path = cfg.rel_path / cfg.tokenizer
+	tokenizer_path = cfg.rel_path / cfg.tokenizer_path
 	if not tokenizer_path.exists():
-		tokenizer_path = Path("./data/") / cfg.tokenizer
+		tokenizer_path = Path("./data/") / cfg.tokenizer_path
 	cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
 except Exception as e:
 	cfg.tokenizer = NaiveTokenizer()

View File

@@ -961,8 +961,8 @@ class Dataset(_Dataset):
 		"""
 		# trim to fit to requested prom/resps levels
-		proms = proms[:, :cfg.model.prom_levels]
-		resps = resps[:, :cfg.model.prom_levels]
+		proms = proms[:, :cfg.model.resp_levels]
+		resps = resps[:, :cfg.model.resp_levels]
 		return dict(
@@ -1466,7 +1466,7 @@ if __name__ == "__main__":
 		if task not in cfg.dataset.tasks_list:
 			continue
-		print(text, task, cfg.model.prom_levels)
+		print(text, task, cfg.model.resp_levels)
 		print( proms.shape, resps.shape )
 		tokens = 0

View File

@@ -37,7 +37,7 @@ def load_engines(training=True):
 		optimizer = None
 		lr_scheduler = None
-		inferencing = cfg.mode == "inferencing" or not model.config.training
+		inferencing = cfg.mode == "inferencing" or not model.config.training or not training
 		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
 		dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 		amp = cfg.inference.amp if inferencing else cfg.trainer.amp

View File

@@ -8,10 +8,10 @@ from einops import rearrange
 from pathlib import Path
 from .emb import g2p, qnt
-from .emb.qnt import trim, trim_random
+from .emb.qnt import trim, trim_random, unload_model
 from .utils import to_device, set_seed, wrapper as ml
-from .config import cfg
+from .config import cfg, Config
 from .models import get_models
 from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
@@ -22,11 +22,15 @@ if deepspeed_available:
 class TTS():
 	def __init__( self, config=None, device=None, amp=None, dtype=None ):
 		self.loading = True
+		self.input_sample_rate = 24000
+		self.output_sample_rate = 24000
+		self.load_config( config=config, device=device, amp=amp, dtype=dtype )
+		self.load_model()
+		self.loading = False
+
+	def load_config( self, config=None, device=None, amp=None, dtype=None ):
 		if config:
+			print("Loading YAML:", config)
 			cfg.load_yaml( config )
 		try:
@@ -52,20 +56,20 @@ class TTS():
 		self.device = device
 		self.dtype = cfg.inference.dtype
 		self.amp = amp
-		self.symmap = None
+	def load_model( self ):
+		load_engines.cache_clear()
+		unload_model()
 		self.engines = load_engines(training=False)
 		for name, engine in self.engines.items():
 			if self.dtype != torch.int8:
 				engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 		self.engines.eval()
-		if self.symmap is None:
-			self.symmap = get_phone_symmap()
-		self.loading = False
+		self.symmap = get_phone_symmap()
+		print("Loaded model")
 	def encode_text( self, text, language="en" ):
 		# already a tensor, return it
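In short, constructing a `TTS` now amounts to `load_config()` followed by `load_model()`, and `load_model()` clears the cached `load_engines()` result and unloads the quantizer first, so a different YAML can in principle be loaded into a running process without restarting it; the web UI's new Settings tab relies on this via `init_tts(..., restart=True)`.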

View File

@@ -33,12 +33,6 @@ class AR_NAR(Base):
 	def causal(self):
 		return "ar" in self.capabilities
-	@property
-	def n_prom_levels(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.prom_levels
-		return cfg.model.prom_levels
 	@property
 	def n_resp_levels(self) -> int:
 		if hasattr(self, "config") and self.config:
@@ -197,7 +191,7 @@
 		# is NAR
 		if max_levels == 0:
-			max_levels = self.n_resp_levels - 1
+			max_levels = self.n_max_levels - 1
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(resps_list):
@@ -373,7 +367,6 @@ def example_usage():
 	# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
 	"""
 	if "mamba" in cfg.model.arch_type:
-		cfg.model.prom_levels = 1
 		cfg.model.resp_levels = 1
 	"""
 	# cfg.model.loss_factors = {}
@@ -383,7 +376,7 @@ def example_usage():
 	def _load_quants(path) -> Tensor:
 		qnt = np.load(path, allow_pickle=True)[()]
-		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
 	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")

View File

@@ -123,7 +123,7 @@ class AudioEmbedding_Old(nn.Module):
 	):
 		super().__init__()
 		# array of embeddings
-		# proms are [0, prom_levels]
+		# proms are [0, resp_levels]
 		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 		# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
@@ -154,7 +154,7 @@
 	):
 		super().__init__()
 		# array of embeddings
-		# proms are [0, prom_levels]
+		# proms are [0, resp_levels]
 		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
 		# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
@@ -282,10 +282,6 @@ class Base(nn.Module):
 	def causal(self) -> bool:
 		raise NotImplementedError
-	@property
-	def n_prom_levels(self) -> int:
-		raise NotImplementedError
 	@property
 	def n_resp_levels(self) -> int:
 		raise NotImplementedError
@@ -368,7 +364,6 @@
 		self.l_padding = l_padding
-		n_prom_tokens = n_audio_tokens
 		arch_type = self.config.arch_type if self.config is not None else "llama"
 		self.arch_type = arch_type
@@ -397,14 +392,14 @@
 		self.len_emb = None
 		if self.version == 1: # legacy
-			n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
-			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
+			n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
+			self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
 		elif self.version < 5:
 			# [1024] * 8
 			self.proms_emb = AudioEmbedding_Old(
-				[n_prom_tokens] * self.n_prom_levels, d_model,
-				levels=self.n_prom_levels if self.version > 3 else None,
+				[n_audio_tokens] * self.n_resp_levels, d_model,
+				levels=self.n_resp_levels if self.version > 3 else None,
 			)
 			# [1024 + STOP] + [1024] * 8
 			self.resps_emb = AudioEmbedding_Old(
@@ -413,7 +408,7 @@
 			)
 		else:
 			self.proms_emb = AudioEmbedding(
-				[n_prom_tokens] * self.n_prom_levels, d_model,
+				[n_audio_tokens] * self.n_resp_levels, d_model,
 				sums=audio_embedding_sums,
 				external_mode=audio_embedding_mode,
 			)

View File

@@ -31,12 +31,6 @@ class NAR(Base):
 	def causal(self):
 		return "len" in self.capabilities
-	@property
-	def n_prom_levels(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.prom_levels
-		return cfg.model.prom_levels
 	@property
 	def n_resp_levels(self) -> int:
 		if hasattr(self, "config") and self.config:
@@ -309,7 +303,6 @@ def example_usage():
 	# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
 	"""
 	if "mamba" in cfg.model.arch_type:
-		cfg.model.prom_levels = 1
 		cfg.model.resp_levels = 1
 	"""
 	# cfg.model.loss_factors = {}
@@ -319,7 +312,7 @@ def example_usage():
 	def _load_quants(path) -> Tensor:
 		qnt = np.load(path, allow_pickle=True)[()]
-		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
 	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")

View File

@@ -19,6 +19,7 @@ tts = None
 layout = {}
 layout["inference"] = {}
 layout["training"] = {}
+layout["settings"] = {}
 for k in layout.keys():
 	layout[k]["inputs"] = { "progress": None }
@@ -37,14 +38,42 @@ def gradio_wrapper(inputs):
 	return decorated
 class timer:
-	def __enter__(self):
-		self.start = perf_counter()
-		return self
-	def __exit__(self, type, value, traceback):
-		print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s')
+	def __init__(self, msg="Elapsed time:"):
+		self.msg = msg
+	def __enter__(self):
+		self.start = perf_counter()
+		return self
+	def __exit__(self, type, value, traceback):
+		msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
+		gr.Info(msg)
+		print(f'[{datetime.now().isoformat()}] {msg}')
+
+# returns a list of models, assuming the models are placed under ./training/ or ./models/
+def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
+	yamls = []
+	for path in paths:
+		if not path.exists():
+			continue
+		for yaml in path.glob("**/*.yaml"):
+			if "/logs/" in str(yaml):
+				continue
+			yamls.append( yaml )
+	return yamls
+
+#
+def load_model( yaml ):
+	gr.Info(f"Loading: {yaml}")
+	init_tts( yaml=Path(yaml), restart=True )
+	gr.Info(f"Loaded model")
+
-def init_tts(restart=False):
+def init_tts(yaml=None, restart=False):
 	global tts
 	if tts is not None:
@@ -53,13 +82,13 @@ def init_tts(restart=False):
 		del tts
 	parser = argparse.ArgumentParser(allow_abbrev=False)
-	parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
+	parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
 	parser.add_argument("--device", type=str, default="cuda")
 	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--dtype", type=str, default="auto")
 	args, unknown = parser.parse_known_args()
-	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
+	tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
 	return tts
 @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
@@ -78,7 +107,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
 	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
-	parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
+	parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
 	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
 	parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
 	parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
@@ -99,7 +128,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		raise ValueError("No reference audio provided.")
 	tts = init_tts()
-	with timer() as t:
+	gr.Info("Inferencing...")
+	with timer("Inferenced in") as t:
 		wav, sr = tts.inference(
 			text=args.text,
 			language=args.language,
@@ -169,6 +200,7 @@ def get_random_prompt():
 # setup args
 parser = argparse.ArgumentParser(allow_abbrev=False)
+parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
 parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
 parser.add_argument("--share", action="store_true")
 parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
@@ -208,7 +240,7 @@ with ui:
 			with gr.Column(scale=7):
 				with gr.Row():
 					layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
-					layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+					#layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
 					layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
 				with gr.Row():
 					layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
@@ -249,6 +281,19 @@
 			)
 			"""
+	with gr.Tab("Settings"):
+		with gr.Row():
+			with gr.Column(scale=7):
+				layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
+			with gr.Column(scale=1):
+				layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
+
+		layout["settings"]["buttons"]["load"].click(
+			fn=load_model,
+			inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
+			outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
+		)
+
 	if os.path.exists("README.md") and args.render_markdown:
 		md = open("README.md", "r", encoding="utf-8").read()
 		# remove HF's metadata