allow loading a different model within the web ui (apparently I did not have the web UI in the documentation)
parent 7b210d9738
commit 3acc54df22

README.md | 24
@@ -20,7 +20,6 @@ Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install g
-
 I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
 
 ## Pre-Trained Model
 
 > [!NOTE]
@@ -196,11 +195,34 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 + **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime at least).
 * `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise.
 
+### Web UI
+
+A Gradio-based web UI is accessible by running `python3 -m vall_e.webui`. You can, optionally, pass:
+
+* `--yaml=./path/to/your/config.yaml`: will load the targeted YAML
+* `--listen 0.0.0.0:7860`: will set the web UI to listen to all IPs at port 7860. Replace the IP and port to your preference.
+
+#### Inference
+
+Synthesizing speech is simple:
+
+* `Input Prompt`: The guiding text prompt. Each new line will be its own generated audio, stitched together at the end.
+* `Audio Input`: The reference audio for the synthesis. Under Gradio, you can trim your clip accordingly, but leaving it as-is works fine.
+* `Output`: The resultant audio.
+* `Inference`: Button to start generating the audio.
+
+All the additional knobs have a description that can be correlated to the above CLI flags.
+
+#### Settings
+
+So far, this only allows you to load a different model without needing to restart. The previous model should seamlessly unload, and the new one will load in its place.
+
 ## To-Do
 
 * [x] train and release a serviceable model for finetuning against.
 * [ ] train and release a ***good*** zero-shot model.
   - this should, hopefully, just require another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
+* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
 * [x] ~~explore alternative setups, like a NAR-only model~~
   - the current experiment of an AR length-predictor + NAR for the rest seems to fall apart...
 * [x] ~~explore better sampling techniques~~
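The Inference tab documented above is a thin wrapper around the `TTS` class that this commit also reworks. As a rough sketch of the programmatic equivalent (the module path and the exact `inference()` keyword names, notably the reference-audio argument, are assumptions on my part):

```python
# Hypothetical programmatic equivalent of the web UI's Inference tab.
from vall_e.inference import TTS

tts = TTS(config="./training/config.yaml")  # same YAML the --yaml flag points at

wav, sr = tts.inference(
    text="Hello world.",                 # the "Input Prompt" field
    language="en",
    references=["./reference.wav"],      # assumed name for the "Audio Input" clip(s)
)
```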
@@ -210,11 +210,10 @@ class ModelExperimentalSettings:
 # I really need to clean this up
 @dataclass()
 class Model:
-    name: str = "" # vanity name for the model
+    name: str = "ar+nar" # vanity name for the model
     version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
     size: str | dict = "full" # preset string or explicitly defined dimensionality
-    resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
-    prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
+    resp_levels: int = 8 # RVQ-bin levels this model supports
     tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
     langs: int = 1 # defined languages (semi-unused)
     tones: int = 1 # defined tones (unsued)

@@ -238,7 +237,10 @@ class Model:
 
     @property
     def max_levels(self):
-        return max(self.prom_levels, self.resp_levels)
+        # return RVQ level range
+        if self.experimental is not None and self.experimental.rvq_level_range:
+            return self.experimental.rvq_level_range[-1]
+        return self.resp_levels
 
     @property
     # required for fp8 as the lengths needs to be divisible by 8
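With `prom_levels` gone, `max_levels` now falls back to `resp_levels` unless an experimental `rvq_level_range` is set, in which case its last entry wins. A standalone restatement of that decision rule (not the real dataclass, just the logic the hunk above introduces):

```python
# Restatement of the new Model.max_levels behavior with made-up values.
from typing import Optional, Sequence

def max_levels(resp_levels: int, rvq_level_range: Optional[Sequence[int]] = None) -> int:
    # an experimental rvq_level_range, when set, caps the level count at its last entry
    if rvq_level_range:
        return rvq_level_range[-1]
    return resp_levels

print(max_levels(8))          # 8 -> falls back to resp_levels
print(max_levels(8, [0, 4]))  # 4 -> last entry of the experimental range
```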
@@ -626,7 +628,7 @@ class Inference:
     use_encodec: bool = True
     use_dac: bool = True
 
-    @cached_property
+    @property
     def dtype(self):
         if self.weight_dtype == "float16":
             return torch.float16

@@ -651,7 +653,7 @@ class Optimizations:
     optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
 
     bitsandbytes: bool = False # use bitsandbytes
-    dadaptation: bool = True # use dadaptation optimizer
+    dadaptation: bool = False # use dadaptation optimizer
     bitnet: bool = False # use bitnet
     fp8: bool = False # use fp8
 

@@ -671,7 +673,8 @@ class Config(BaseConfig):
     bitsandbytes: dict | list | None = None # deprecated
     optimizations: Optimizations = field(default_factory=lambda: Optimizations)
 
-    tokenizer: str = "./tokenizer.json"
+    tokenizer: str | None = None
+    tokenizer_path: str = "./tokenizer.json"
 
     sample_rate: int = 24_000
     variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
@@ -760,14 +763,21 @@ class Config(BaseConfig):
         self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
         self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
 
+        # do cleanup
         for model in self.models:
             if not isinstance( model, dict ):
                 continue
 
+            if "prom_levels" in model:
+                del model["prom_levels"]
+
+            if "interleave" in model:
+                del model["interleave"]
+
             if "audio_embedding_sums" not in model:
                 continue
 
-            if not model["experimental"]:
+            if "experimental" not in model or not model["experimental"]:
                 model["experimental"] = {}
 
             model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
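In effect, this block migrates older YAMLs in place: the retired `prom_levels` and `interleave` keys are dropped, and a top-level `audio_embedding_sums` is folded into the `experimental` sub-dict. A toy illustration of the same transformation on a plain dict (standalone, not the actual `Config` method):

```python
# Example legacy model entry as it might appear in an old YAML.
model = {"name": "ar+nar", "prom_levels": 8, "interleave": False, "audio_embedding_sums": True}

for key in ("prom_levels", "interleave"):
    model.pop(key, None)  # retired keys are simply discarded

if "experimental" not in model or not model["experimental"]:
    model["experimental"] = {}
model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")

print(model)  # {'name': 'ar+nar', 'experimental': {'audio_embedding_sums': True}}
```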
@@ -837,9 +847,9 @@ class Config(BaseConfig):
         try:
             from transformers import PreTrainedTokenizerFast
 
-            tokenizer_path = cfg.rel_path / cfg.tokenizer
+            tokenizer_path = cfg.rel_path / cfg.tokenizer_path
             if not tokenizer_path.exists():
-                tokenizer_path = Path("./data/") / cfg.tokenizer
+                tokenizer_path = Path("./data/") / cfg.tokenizer_path
             cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
         except Exception as e:
             cfg.tokenizer = NaiveTokenizer()
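The rename avoids clobbering: previously `cfg.tokenizer` started life as the path string and was then overwritten with the tokenizer object, whereas now the path lives in `tokenizer_path` and `tokenizer` only ever holds the loaded object (or the fallback). A small hedged illustration, assuming the YAML has been loaded and the code above has run:

```python
# Illustrative only; attribute roles per the hunk above.
from vall_e.config import cfg

print(cfg.tokenizer_path)   # still the configured path, e.g. "./tokenizer.json"
print(type(cfg.tokenizer))  # PreTrainedTokenizerFast, or NaiveTokenizer as the fallback
```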
@@ -961,8 +961,8 @@ class Dataset(_Dataset):
         """
 
         # trim to fit to requested prom/resps levels
-        proms = proms[:, :cfg.model.prom_levels]
-        resps = resps[:, :cfg.model.prom_levels]
+        proms = proms[:, :cfg.model.resp_levels]
+        resps = resps[:, :cfg.model.resp_levels]
 
 
         return dict(

@@ -1466,7 +1466,7 @@ if __name__ == "__main__":
             if task not in cfg.dataset.tasks_list:
                 continue
 
-            print(text, task, cfg.model.prom_levels)
+            print(text, task, cfg.model.resp_levels)
             print( proms.shape, resps.shape )
 
             tokens = 0
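With `prom_levels` retired, `resp_levels` alone now decides how many RVQ levels survive the trim for both the prompt and the response. A toy tensor example of that slice (shapes here are arbitrary):

```python
# Quantized audio arrives as [frames, levels]; only the first resp_levels codebook
# levels are kept.
import torch

resp_levels = 8
codes = torch.randint(0, 1024, (500, 12))  # e.g. a 12-level quantization
codes = codes[:, :resp_levels]
print(codes.shape)  # torch.Size([500, 8])
```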
@@ -37,7 +37,7 @@ def load_engines(training=True):
     optimizer = None
     lr_scheduler = None
 
-    inferencing = cfg.mode == "inferencing" or not model.config.training
+    inferencing = cfg.mode == "inferencing" or not model.config.training or not training
     backend = cfg.inference.backend if inferencing else cfg.trainer.backend
     dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
     amp = cfg.inference.amp if inferencing else cfg.trainer.amp
@@ -8,10 +8,10 @@ from einops import rearrange
 from pathlib import Path
 
 from .emb import g2p, qnt
-from .emb.qnt import trim, trim_random
+from .emb.qnt import trim, trim_random, unload_model
 from .utils import to_device, set_seed, wrapper as ml
 
-from .config import cfg
+from .config import cfg, Config
 from .models import get_models
 from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize

@@ -23,10 +23,14 @@ class TTS():
     def __init__( self, config=None, device=None, amp=None, dtype=None ):
         self.loading = True
 
-        self.input_sample_rate = 24000
-        self.output_sample_rate = 24000
+        self.load_config( config=config, device=device, amp=amp, dtype=dtype )
+        self.load_model()
 
+        self.loading = False
+
+    def load_config( self, config=None, device=None, amp=None, dtype=None ):
         if config:
+            print("Loading YAML:", config)
             cfg.load_yaml( config )
 
         try:

@@ -53,7 +57,10 @@ class TTS():
         self.dtype = cfg.inference.dtype
         self.amp = amp
 
-        self.symmap = None
+
+    def load_model( self ):
+        load_engines.cache_clear()
+        unload_model()
 
         self.engines = load_engines(training=False)
         for name, engine in self.engines.items():

@@ -61,11 +68,8 @@ class TTS():
             engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 
         self.engines.eval()
-        if self.symmap is None:
-            self.symmap = get_phone_symmap()
-
-        self.loading = False
+        self.symmap = get_phone_symmap()
+        print("Loaded model")
 
     def encode_text( self, text, language="en" ):
         # already a tensor, return it
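Taken together, splitting `__init__` into `load_config()` and `load_model()` (plus clearing the cached `load_engines` and calling `unload_model()` from `vall_e.emb.qnt`) is what lets the web UI swap models on a live instance. A hedged sketch of that flow, with placeholder YAML paths:

```python
# Sketch of hot-swapping models on an existing TTS object, per the methods added above.
from vall_e.inference import TTS

tts = TTS(config="./training/old_model/config.yaml")

# later: point the global cfg at a different YAML and rebuild the engines in place
tts.load_config(config="./training/new_model/config.yaml")
tts.load_model()  # clears the load_engines cache, unloads the codec model, reloads engines
```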
@@ -33,12 +33,6 @@ class AR_NAR(Base):
     def causal(self):
         return "ar" in self.capabilities
 
-    @property
-    def n_prom_levels(self) -> int:
-        if hasattr(self, "config") and self.config:
-            return self.config.prom_levels
-        return cfg.model.prom_levels
-
     @property
     def n_resp_levels(self) -> int:
         if hasattr(self, "config") and self.config:

@@ -197,7 +191,7 @@ class AR_NAR(Base):
 
         # is NAR
         if max_levels == 0:
-            max_levels = self.n_resp_levels - 1
+            max_levels = self.n_max_levels - 1
 
         # expand if given a raw 1D tensor
         for i, resp in enumerate(resps_list):

@@ -373,7 +367,6 @@ def example_usage():
     # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
     """
     if "mamba" in cfg.model.arch_type:
-        cfg.model.prom_levels = 1
         cfg.model.resp_levels = 1
     """
     # cfg.model.loss_factors = {}

@@ -383,7 +376,7 @@ def example_usage():
 
     def _load_quants(path) -> Tensor:
         qnt = np.load(path, allow_pickle=True)[()]
-        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
 
     qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
@@ -123,7 +123,7 @@ class AudioEmbedding_Old(nn.Module):
     ):
         super().__init__()
         # array of embeddings
-        # proms are [0, prom_levels]
+        # proms are [0, resp_levels]
         # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
         self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
         # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)

@@ -154,7 +154,7 @@ class AudioEmbedding(nn.Module):
     ):
         super().__init__()
         # array of embeddings
-        # proms are [0, prom_levels]
+        # proms are [0, resp_levels]
         # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
         # + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
         self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])

@@ -282,10 +282,6 @@ class Base(nn.Module):
     def causal(self) -> bool:
         raise NotImplementedError
 
-    @property
-    def n_prom_levels(self) -> int:
-        raise NotImplementedError
-
     @property
     def n_resp_levels(self) -> int:
         raise NotImplementedError

@@ -368,7 +364,6 @@ class Base(nn.Module):
 
         self.l_padding = l_padding
 
-        n_prom_tokens = n_audio_tokens
         arch_type = self.config.arch_type if self.config is not None else "llama"
 
         self.arch_type = arch_type

@@ -397,14 +392,14 @@ class Base(nn.Module):
         self.len_emb = None
 
         if self.version == 1: # legacy
-            n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
-            self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
+            n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
+            self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
             self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
         elif self.version < 5:
             # [1024] * 8
             self.proms_emb = AudioEmbedding_Old(
-                [n_prom_tokens] * self.n_prom_levels, d_model,
-                levels=self.n_prom_levels if self.version > 3 else None,
+                [n_audio_tokens] * self.n_resp_levels, d_model,
+                levels=self.n_resp_levels if self.version > 3 else None,
             )
             # [1024 + STOP] + [1024] * 8
             self.resps_emb = AudioEmbedding_Old(

@@ -413,7 +408,7 @@ class Base(nn.Module):
             )
         else:
             self.proms_emb = AudioEmbedding(
-                [n_prom_tokens] * self.n_prom_levels, d_model,
+                [n_audio_tokens] * self.n_resp_levels, d_model,
                 sums=audio_embedding_sums,
                 external_mode=audio_embedding_mode,
             )
@@ -31,12 +31,6 @@ class NAR(Base):
     def causal(self):
         return "len" in self.capabilities
 
-    @property
-    def n_prom_levels(self) -> int:
-        if hasattr(self, "config") and self.config:
-            return self.config.prom_levels
-        return cfg.model.prom_levels
-
     @property
     def n_resp_levels(self) -> int:
         if hasattr(self, "config") and self.config:

@@ -309,7 +303,6 @@ def example_usage():
     # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
     """
    if "mamba" in cfg.model.arch_type:
-        cfg.model.prom_levels = 1
         cfg.model.resp_levels = 1
     """
     # cfg.model.loss_factors = {}

@@ -319,7 +312,7 @@ def example_usage():
 
     def _load_quants(path) -> Tensor:
         qnt = np.load(path, allow_pickle=True)[()]
-        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
 
     qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
@@ -19,6 +19,7 @@ tts = None
 layout = {}
 layout["inference"] = {}
 layout["training"] = {}
+layout["settings"] = {}
 
 for k in layout.keys():
     layout[k]["inputs"] = { "progress": None }

@@ -37,14 +38,42 @@ def gradio_wrapper(inputs):
     return decorated
 
 class timer:
-    def __enter__(self):
-        self.start = perf_counter()
-        return self
+    def __init__(self, msg="Elapsed time:"):
+        self.msg = msg
 
-    def __exit__(self, type, value, traceback):
-        print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s')
+    def __enter__(self):
+        self.start = perf_counter()
+        return self
 
-def init_tts(restart=False):
+    def __exit__(self, type, value, traceback):
+        msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
+
+        gr.Info(msg)
+        print(f'[{datetime.now().isoformat()}] {msg}')
+
+# returns a list of models, assuming the models are placed under ./training/ or ./models/
+def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
+    yamls = []
+
+    for path in paths:
+        if not path.exists():
+            continue
+
+        for yaml in path.glob("**/*.yaml"):
+            if "/logs/" in str(yaml):
+                continue
+
+            yamls.append( yaml )
+
+    return yamls
+
+#
+def load_model( yaml ):
+    gr.Info(f"Loading: {yaml}")
+    init_tts( yaml=Path(yaml), restart=True )
+    gr.Info(f"Loaded model")
+
+def init_tts(yaml=None, restart=False):
 global tts
 
 if tts is not None:
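The reworked `timer` now carries a message and reports both through `gr.Info` and the console. A self-contained check of the same pattern (the `gr.Info` call is left out here, since a plain script has no Gradio UI to surface the toast to):

```python
# Standalone version of the timer pattern above.
from datetime import datetime
from time import perf_counter, sleep

class timer:
    def __init__(self, msg="Elapsed time:"):
        self.msg = msg
    def __enter__(self):
        self.start = perf_counter()
        return self
    def __exit__(self, type, value, traceback):
        print(f'[{datetime.now().isoformat()}] {self.msg} {(perf_counter() - self.start):.3f}s')

with timer("Slept for"):
    sleep(0.1)  # prints something like "[2024-...] Slept for 0.100s"
```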
@@ -53,13 +82,13 @@ def init_tts(restart=False):
         del tts
 
     parser = argparse.ArgumentParser(allow_abbrev=False)
-    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
+    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--amp", action="store_true")
     parser.add_argument("--dtype", type=str, default="auto")
     args, unknown = parser.parse_known_args()
 
-    tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
+    tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
     return tts
 
 @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())

@@ -78,7 +107,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--language", type=str, default="en")
     parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
     parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
-    parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
+    parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
     parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
     parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
     parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])

@@ -99,7 +128,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         raise ValueError("No reference audio provided.")
 
     tts = init_tts()
-    with timer() as t:
+
+    gr.Info("Inferencing...")
+    with timer("Inferenced in") as t:
         wav, sr = tts.inference(
             text=args.text,
             language=args.language,

@@ -169,6 +200,7 @@ def get_random_prompt():
 
 # setup args
 parser = argparse.ArgumentParser(allow_abbrev=False)
+parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
 parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
 parser.add_argument("--share", action="store_true")
 parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)

@@ -208,7 +240,7 @@ with ui:
         with gr.Column(scale=7):
             with gr.Row():
                 layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
-                layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
+                #layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
                 layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
             with gr.Row():
                 layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
@@ -249,6 +281,19 @@ with ui:
                 )
                 """
 
+    with gr.Tab("Settings"):
+        with gr.Row():
+            with gr.Column(scale=7):
+                layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
+            with gr.Column(scale=1):
+                layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
+
+        layout["settings"]["buttons"]["load"].click(
+            fn=load_model,
+            inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
+            outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
+        )
+
     if os.path.exists("README.md") and args.render_markdown:
         md = open("README.md", "r", encoding="utf-8").read()
         # remove HF's metadata
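The new Settings tab is nothing more than a dropdown of discovered YAMLs wired to a button that calls `load_model`. A bare-bones reproduction of that wiring outside the repo's larger `Blocks` layout, with the model list and handler replaced by simplified stand-ins for `get_model_paths()` and `load_model()`:

```python
# Minimal sketch of the Settings tab's dropdown -> button -> handler pattern.
import gradio as gr

def load_model(path):
    return f"Loaded {path}"  # the real handler calls init_tts(yaml=..., restart=True)

with gr.Blocks() as demo:
    with gr.Row():
        models = gr.Dropdown(choices=["./training/a/config.yaml", "./training/b/config.yaml"], label="Model")
        status = gr.Textbox(label="Status")
    gr.Button("Load Model").click(fn=load_model, inputs=[models], outputs=[status])

demo.launch()
```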