allow loading a different model within the web ui (apparently I did not have the web UI in the documentation)
parent 7b210d9738, commit 3acc54df22
README.md

@@ -20,7 +20,6 @@ Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install g
I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.

## Pre-Trained Model

> [!NOTE]
@@ -196,11 +195,34 @@ And some experimental sampling flags you can use too (your mileage will ***defin
+ **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime at least).
* `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise.

### Web UI

A Gradio-based web UI is accessible by running `python3 -m vall_e.webui`. You can, optionally, pass:

* `--yaml=./path/to/your/config.yaml`: will load the targeted YAML
* `--listen 0.0.0.0:7860`: will set the web UI to listen on all IPs at port 7860. Replace the IP and port with your preference.
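For example, to launch against a specific config while listening on all interfaces (the YAML path below is only an illustration):

```bash
# assumes your model config lives at ./training/config.yaml
python3 -m vall_e.webui --yaml=./training/config.yaml --listen 0.0.0.0:7860
```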

#### Inference

Synthesizing speech is simple:

* `Input Prompt`: The guiding text prompt. Each new line will be its own generated audio, stitched together at the end.
* `Audio Input`: The reference audio for the synthesis. Under Gradio, you can trim your clip accordingly, but leaving it as-is works fine.
* `Output`: The resultant audio.
* `Inference`: Button to start generating the audio.

All the additional knobs have descriptions that correspond to the CLI flags above.

#### Settings

So far, this only allows you to load a different model without needing to restart. The previous model should seamlessly unload, and the new one will load in its place.
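Under the hood, the "Load Model" button simply re-initializes the global `TTS` object against the selected YAML (see the `webui.py` changes further down in this commit). A rough programmatic equivalent, with an illustrative config path:

```python
from pathlib import Path
from vall_e.webui import init_tts

# drops the currently loaded model (if any) and constructs a new TTS instance
# from the given YAML; the path here is only an example
init_tts( yaml=Path("./training/config.yaml"), restart=True )
```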

## To-Do

* [x] train and release a serviceable model for finetuning against.
* [ ] train and release a ***good*** zero-shot model.
  - this should, hopefully, just require another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
* [x] ~~explore alternative setups, like a NAR-only model~~
  - the current experiment of an AR length-predictor + NAR for the rest seems to fall apart...
* [x] ~~explore better sampling techniques~~
@@ -210,11 +210,10 @@ class ModelExperimentalSettings:
# I really need to clean this up
@dataclass()
class Model:
    name: str = "" # vanity name for the model
    name: str = "ar+nar" # vanity name for the model
    version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
    size: str | dict = "full" # preset string or explicitly defined dimensionality
    resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
    prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
    resp_levels: int = 8 # RVQ-bin levels this model supports
    tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
    langs: int = 1 # defined languages (semi-unused)
    tones: int = 1 # defined tones (unsued)
@@ -238,7 +237,10 @@ class Model:

    @property
    def max_levels(self):
        return max(self.prom_levels, self.resp_levels)
        # return RVQ level range
        if self.experimental is not None and self.experimental.rvq_level_range:
            return self.experimental.rvq_level_range[-1]
        return self.resp_levels

    @property
    # required for fp8 as the lengths needs to be divisible by 8
@@ -626,7 +628,7 @@ class Inference:
    use_encodec: bool = True
    use_dac: bool = True

    @cached_property
    @property
    def dtype(self):
        if self.weight_dtype == "float16":
            return torch.float16
@@ -651,7 +653,7 @@ class Optimizations:
    optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)

    bitsandbytes: bool = False # use bitsandbytes
    dadaptation: bool = True # use dadaptation optimizer
    dadaptation: bool = False # use dadaptation optimizer
    bitnet: bool = False # use bitnet
    fp8: bool = False # use fp8
@@ -671,7 +673,8 @@ class Config(BaseConfig):
    bitsandbytes: dict | list | None = None # deprecated
    optimizations: Optimizations = field(default_factory=lambda: Optimizations)

    tokenizer: str = "./tokenizer.json"
    tokenizer: str | None = None
    tokenizer_path: str = "./tokenizer.json"

    sample_rate: int = 24_000
    variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
@@ -760,14 +763,21 @@ class Config(BaseConfig):
        self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
        self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]

        # do cleanup
        for model in self.models:
            if not isinstance( model, dict ):
                continue

            if "prom_levels" in model:
                del model["prom_levels"]

            if "interleave" in model:
                del model["interleave"]

            if "audio_embedding_sums" not in model:
                continue

            if not model["experimental"]:
            if "experimental" not in model or not model["experimental"]:
                model["experimental"] = {}

            model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
@@ -837,9 +847,9 @@ class Config(BaseConfig):
try:
    from transformers import PreTrainedTokenizerFast

    tokenizer_path = cfg.rel_path / cfg.tokenizer
    tokenizer_path = cfg.rel_path / cfg.tokenizer_path
    if not tokenizer_path.exists():
        tokenizer_path = Path("./data/") / cfg.tokenizer
        tokenizer_path = Path("./data/") / cfg.tokenizer_path
    cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
except Exception as e:
    cfg.tokenizer = NaiveTokenizer()
@@ -961,8 +961,8 @@ class Dataset(_Dataset):
        """

        # trim to fit to requested prom/resps levels
        proms = proms[:, :cfg.model.prom_levels]
        resps = resps[:, :cfg.model.prom_levels]
        proms = proms[:, :cfg.model.resp_levels]
        resps = resps[:, :cfg.model.resp_levels]

        return dict(
@@ -1466,7 +1466,7 @@ if __name__ == "__main__":
        if task not in cfg.dataset.tasks_list:
            continue

        print(text, task, cfg.model.prom_levels)
        print(text, task, cfg.model.resp_levels)
        print( proms.shape, resps.shape )

        tokens = 0
@@ -37,7 +37,7 @@ def load_engines(training=True):
        optimizer = None
        lr_scheduler = None

        inferencing = cfg.mode == "inferencing" or not model.config.training
        inferencing = cfg.mode == "inferencing" or not model.config.training or not training
        backend = cfg.inference.backend if inferencing else cfg.trainer.backend
        dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
        amp = cfg.inference.amp if inferencing else cfg.trainer.amp
@@ -8,10 +8,10 @@ from einops import rearrange
from pathlib import Path

from .emb import g2p, qnt
from .emb.qnt import trim, trim_random
from .emb.qnt import trim, trim_random, unload_model
from .utils import to_device, set_seed, wrapper as ml

from .config import cfg
from .config import cfg, Config
from .models import get_models
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
@@ -23,10 +23,14 @@ class TTS():
    def __init__( self, config=None, device=None, amp=None, dtype=None ):
        self.loading = True

        self.input_sample_rate = 24000
        self.output_sample_rate = 24000
        self.load_config( config=config, device=device, amp=amp, dtype=dtype )
        self.load_model()

        self.loading = False

    def load_config( self, config=None, device=None, amp=None, dtype=None ):
        if config:
            print("Loading YAML:", config)
            cfg.load_yaml( config )

        try:
@@ -53,7 +57,10 @@ class TTS():
        self.dtype = cfg.inference.dtype
        self.amp = amp

        self.symmap = None

    def load_model( self ):
        load_engines.cache_clear()
        unload_model()

        self.engines = load_engines(training=False)
        for name, engine in self.engines.items():
@@ -61,11 +68,8 @@ class TTS():
            engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

        self.engines.eval()

        if self.symmap is None:
            self.symmap = get_phone_symmap()

        self.loading = False
        print("Loaded model")

    def encode_text( self, text, language="en" ):
        # already a tensor, return it
@@ -33,12 +33,6 @@ class AR_NAR(Base):
    def causal(self):
        return "ar" in self.capabilities

    @property
    def n_prom_levels(self) -> int:
        if hasattr(self, "config") and self.config:
            return self.config.prom_levels
        return cfg.model.prom_levels

    @property
    def n_resp_levels(self) -> int:
        if hasattr(self, "config") and self.config:
@@ -197,7 +191,7 @@ class AR_NAR(Base):

        # is NAR
        if max_levels == 0:
            max_levels = self.n_resp_levels - 1
            max_levels = self.n_max_levels - 1

        # expand if given a raw 1D tensor
        for i, resp in enumerate(resps_list):
@@ -373,7 +367,6 @@ def example_usage():
    # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
    """
    if "mamba" in cfg.model.arch_type:
        cfg.model.prom_levels = 1
        cfg.model.resp_levels = 1
    """
    # cfg.model.loss_factors = {}
@@ -383,7 +376,7 @@ def example_usage():

    def _load_quants(path) -> Tensor:
        qnt = np.load(path, allow_pickle=True)[()]
        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)

    qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
@@ -123,7 +123,7 @@ class AudioEmbedding_Old(nn.Module):
    ):
        super().__init__()
        # array of embeddings
        # proms are [0, prom_levels]
        # proms are [0, resp_levels]
        # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
        # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
@@ -154,7 +154,7 @@ class AudioEmbedding(nn.Module):
    ):
        super().__init__()
        # array of embeddings
        # proms are [0, prom_levels]
        # proms are [0, resp_levels]
        # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
        # + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
@@ -282,10 +282,6 @@ class Base(nn.Module):
    def causal(self) -> bool:
        raise NotImplementedError

    @property
    def n_prom_levels(self) -> int:
        raise NotImplementedError

    @property
    def n_resp_levels(self) -> int:
        raise NotImplementedError
@@ -368,7 +364,6 @@ class Base(nn.Module):

        self.l_padding = l_padding

        n_prom_tokens = n_audio_tokens
        arch_type = self.config.arch_type if self.config is not None else "llama"

        self.arch_type = arch_type
@@ -397,14 +392,14 @@ class Base(nn.Module):
        self.len_emb = None

        if self.version == 1: # legacy
            n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
            self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
            n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
            self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
            self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
        elif self.version < 5:
            # [1024] * 8
            self.proms_emb = AudioEmbedding_Old(
                [n_prom_tokens] * self.n_prom_levels, d_model,
                levels=self.n_prom_levels if self.version > 3 else None,
                [n_audio_tokens] * self.n_resp_levels, d_model,
                levels=self.n_resp_levels if self.version > 3 else None,
            )
            # [1024 + STOP] + [1024] * 8
            self.resps_emb = AudioEmbedding_Old(
@@ -413,7 +408,7 @@ class Base(nn.Module):
            )
        else:
            self.proms_emb = AudioEmbedding(
                [n_prom_tokens] * self.n_prom_levels, d_model,
                [n_audio_tokens] * self.n_resp_levels, d_model,
                sums=audio_embedding_sums,
                external_mode=audio_embedding_mode,
            )
@@ -31,12 +31,6 @@ class NAR(Base):
    def causal(self):
        return "len" in self.capabilities

    @property
    def n_prom_levels(self) -> int:
        if hasattr(self, "config") and self.config:
            return self.config.prom_levels
        return cfg.model.prom_levels

    @property
    def n_resp_levels(self) -> int:
        if hasattr(self, "config") and self.config:
@@ -309,7 +303,6 @@ def example_usage():
    # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
    """
    if "mamba" in cfg.model.arch_type:
        cfg.model.prom_levels = 1
        cfg.model.resp_levels = 1
    """
    # cfg.model.loss_factors = {}
@@ -319,7 +312,7 @@ def example_usage():

    def _load_quants(path) -> Tensor:
        qnt = np.load(path, allow_pickle=True)[()]
        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)

    qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
@@ -19,6 +19,7 @@ tts = None
layout = {}
layout["inference"] = {}
layout["training"] = {}
layout["settings"] = {}

for k in layout.keys():
    layout[k]["inputs"] = { "progress": None }
@@ -37,14 +38,42 @@ def gradio_wrapper(inputs):
    return decorated

class timer:
    def __init__(self, msg="Elapsed time:"):
        self.msg = msg

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s')
        msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'

def init_tts(restart=False):
        gr.Info(msg)
        print(f'[{datetime.now().isoformat()}] {msg}')

# returns a list of models, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
    yamls = []

    for path in paths:
        if not path.exists():
            continue

        for yaml in path.glob("**/*.yaml"):
            if "/logs/" in str(yaml):
                continue

            yamls.append( yaml )

    return yamls

#
def load_model( yaml ):
    gr.Info(f"Loading: {yaml}")
    init_tts( yaml=Path(yaml), restart=True )
    gr.Info(f"Loaded model")

def init_tts(yaml=None, restart=False):
    global tts

    if tts is not None:
@@ -53,13 +82,13 @@ def init_tts(restart=False):
        del tts

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default="auto")
    args, unknown = parser.parse_known_args()

    tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
    tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
    return tts

@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
@@ -78,7 +107,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
    parser.add_argument("--language", type=str, default="en")
    parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
    parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
    parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
    parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
    parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
    parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
    parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
@@ -99,7 +128,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
        raise ValueError("No reference audio provided.")

    tts = init_tts()
    with timer() as t:

    gr.Info("Inferencing...")
    with timer("Inferenced in") as t:
        wav, sr = tts.inference(
            text=args.text,
            language=args.language,
@@ -169,6 +200,7 @@ def get_random_prompt():

# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
@@ -208,7 +240,7 @@ with ui:
        with gr.Column(scale=7):
            with gr.Row():
                layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
                layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
                #layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
                layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
            with gr.Row():
                layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
@@ -249,6 +281,19 @@ with ui:
            )
    """

    with gr.Tab("Settings"):
        with gr.Row():
            with gr.Column(scale=7):
                layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
            with gr.Column(scale=1):
                layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")

        layout["settings"]["buttons"]["load"].click(
            fn=load_model,
            inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
            outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
        )

    if os.path.exists("README.md") and args.render_markdown:
        md = open("README.md", "r", encoding="utf-8").read()
        # remove HF's metadata