Added an option to load directly from a model state dict instead of a YAML (to-do: do the same for LoRAs), and automatically download the default model if none is provided.
commit ccf71dc1b6 (parent a96f5aee32)
@@ -21,6 +21,7 @@ from functools import cached_property
 from pathlib import Path
 
 from .utils.distributed import world_size
+from .utils.io import torch_load
 from .utils import set_seed, prune_missing
 
 @dataclass()
@@ -30,7 +31,13 @@ class BaseConfig:
 
     @property
     def cfg_path(self):
-        return Path(self.yaml_path.parent) if self.yaml_path is not None else Path(__file__).parent.parent / "data"
+        if self.yaml_path:
+            return Path(self.yaml_path.parent)
+
+        if self.model_path:
+            return Path(self.model_path.parent)
+
+        return Path(__file__).parent.parent / "data"
 
     @property
     def rel_path(self):
@@ -93,8 +100,6 @@ class BaseConfig:
     def prune_missing( cls, yaml ):
         default = cls(**{})
         default.format()
-        #default = json.loads(default.dumps())
-
         yaml, missing = prune_missing( source=default, dest=yaml )
         if missing:
             _logger.warning(f'Missing keys in YAML: {missing}')
@@ -108,6 +113,17 @@ class BaseConfig:
         state = cls.prune_missing( state )
         return cls(**state)
 
+    @classmethod
+    def from_model( cls, model_path ):
+        if not model_path.exists():
+            raise Exception(f'Model path does not exist: {model_path}')
+
+        # load state dict and copy its stored model config
+        state_dict = torch_load( model_path )["config"]
+
+        state = { "models": [ state_dict ], "trainer": { "load_state_dict": True }, "model_path": model_path }
+        return cls(**state)
+
     @classmethod
     def from_cli(cls, args=sys.argv):
         # legacy support for yaml=`` format
@@ -117,8 +133,12 @@ class BaseConfig:
 
         parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
         parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
+        parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
         args, unknown = parser.parse_known_args(args=args)
 
+        if args.model:
+            return cls.from_model( args.model )
+
         if args.yaml:
             return cls.from_yaml( args.yaml )
 
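A minimal usage sketch of the new entry point; the `vall_e.config` module path and the example file path are assumptions, while the class, method, and keyword names come from the hunks above:

    from pathlib import Path

    from vall_e.config import Config  # assumed module path; Config subclasses BaseConfig

    # build a config straight from a serialized checkpoint instead of a YAML:
    # the "config" entry stored in the state dict seeds cfg.models, and
    # trainer.load_state_dict is forced on so the raw weights get loaded
    cfg = Config.from_model( Path("./data/models/ar+nar-llama-8.sft") )
    cfg.format( training=False )

The same path is reachable from the command line: `from_cli` checks the new `--model` argument (or the `VALLE_MODEL` environment variable) before falling back to `--yaml`.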
@@ -807,10 +827,14 @@ class Config(BaseConfig):
             return diskcache.Cache(self.cache_dir).memoize
         return lambda: lambda x: x
 
-    # I don't remember why this is needed
+    # this gets called from vall_e.inference
     def load_yaml( self, config_path ):
         tmp = Config.from_yaml( config_path )
         self.__dict__.update(tmp.__dict__)
 
+    def load_model( self, config_path ):
+        tmp = Config.from_model( config_path )
+        self.__dict__.update(tmp.__dict__)
+
     def load_hdf5( self, write=False ):
         if hasattr(self, 'hdf5'):
@@ -870,7 +894,27 @@ class Config(BaseConfig):
         if isinstance(self.optimizations, type):
             self.optimizations = dict()
 
-        self.dataset = Dataset(**self.dataset)
+        if isinstance( self.dataset, dict ):
+            self.dataset = Dataset(**self.dataset)
+
+        if isinstance( self.hyperparameters, dict ):
+            self.hyperparameters = Hyperparameters(**self.hyperparameters)
+
+        if isinstance( self.evaluation, dict ):
+            self.evaluation = Evaluation(**self.evaluation)
+
+        if isinstance( self.trainer, dict ):
+            self.trainer = Trainer(**self.trainer)
+
+        if isinstance( self.trainer.deepspeed, dict ):
+            self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
+
+        if isinstance( self.inference, dict ):
+            self.inference = Inference(**self.inference)
+
+        if isinstance( self.optimizations, dict ):
+            self.optimizations = Optimizations(**self.optimizations)
+
         # convert to expanded paths
         self.dataset.training = [ self.expand(dir) for dir in self.dataset.training ]
         self.dataset.validation = [ self.expand(dir) for dir in self.dataset.validation ]
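The isinstance guards make this coercion step safe to run on a config whose fields may already be dataclass instances, which is what happens when the config is built programmatically from a state dict rather than parsed out of YAML. The pattern in isolation, as a self-contained sketch with stand-in dataclasses (not the project's own):

    from dataclasses import dataclass, field
    from typing import Union

    @dataclass
    class Trainer:
        load_state_dict: bool = False

    @dataclass
    class Config:
        trainer: Union[dict, Trainer] = field(default_factory=dict)

        def format(self):
            # only coerce when the field is still a plain dict, so calling
            # format() on an already-formatted Config is a no-op
            if isinstance(self.trainer, dict):
                self.trainer = Trainer(**self.trainer)

    cfg = Config(trainer={"load_state_dict": True})
    cfg.format()
    cfg.format()  # second call is harmless: trainer is already a Trainer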
@@ -906,28 +950,15 @@ class Config(BaseConfig):
                 model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
 
 
-        self.models = [ Model(**model) for model in self.models ]
-        self.loras = [ LoRA(**lora) for lora in self.loras ]
+        self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
+        self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
 
         if not self.models:
             self.models = [ Model() ]
 
         for model in self.models:
-            if not isinstance( model.experimental, dict ):
-                continue
-            model.experimental = ModelExperimentalSettings(**model.experimental)
-
-        self.hyperparameters = Hyperparameters(**self.hyperparameters)
-
-        self.evaluation = Evaluation(**self.evaluation)
-
-        self.trainer = Trainer(**self.trainer)
-
-        if not isinstance(self.trainer.deepspeed, type):
-            self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
-
-        self.inference = Inference(**self.inference)
-        self.optimizations = Optimizations(**self.optimizations)
+            if isinstance( model.experimental, dict ):
+                model.experimental = ModelExperimentalSettings(**model.experimental)
 
         if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
             self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
@@ -961,7 +992,7 @@ class Config(BaseConfig):
         try:
             from transformers import PreTrainedTokenizerFast
 
-            tokenizer_path = self.rel_path / self.tokenizer_path if self.yaml_path is not None else None
+            tokenizer_path = self.rel_path / self.tokenizer_path
             if tokenizer_path and not tokenizer_path.exists():
                 tokenizer_path = Path("./data/") / self.tokenizer_path
 
@@ -57,6 +57,10 @@ def load_engines(training=True, **model_kwargs):
             tag = open(checkpoint_path).read()
             checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
 
+        # if loaded using --model=
+        if cfg.model_path and cfg.model_path.exists():
+            load_path = cfg.model_path
+
         if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
             _logger.warning(f"Checkpoint missing, but weights found: {load_path}")
             loads_state_dict = True
@@ -19,6 +19,7 @@ from .models import get_models
 from .models.lora import enable_lora
 from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
+from .models import download_model, DEFAULT_MODEL_PATH
 
 if deepspeed_available:
     import deepspeed
@@ -34,9 +35,18 @@ class TTS():
         self.loading = False
 
     def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
-        if config:
+        if not config:
+            download_model()
+            config = DEFAULT_MODEL_PATH
+
+        if config.suffix == ".yaml":
             _logger.info(f"Loading YAML: {config}")
             cfg.load_yaml( config )
+        elif config.suffix == ".sft":
+            _logger.info(f"Loading model: {config}")
+            cfg.load_model( config )
+        else:
+            raise Exception(f"Unknown config passed: {config}")
 
         try:
             cfg.format( training=False )
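In practice this means the inference wrapper can be constructed without any config at all. A sketch of the three paths, assuming `TTS` lives in `vall_e.inference` and that its constructor forwards `config` to `load_config` (both assumptions, as are the example paths; only the suffix handling above is from the commit):

    from pathlib import Path
    from vall_e.inference import TTS  # assumed module path

    tts = TTS()  # no config: download_model() fetches the default ar+nar-llama-8.sft, then loads it
    tts = TTS( config=Path("./my_run/config.yaml") )              # .yaml: parsed with cfg.load_yaml()
    tts = TTS( config=Path("./data/models/ar+nar-llama-8.sft") )  # .sft: loaded with cfg.load_model()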
@@ -4,31 +4,20 @@ import requests
 from tqdm import tqdm
 from pathlib import Path
 
+import time
+
 _logger = logging.getLogger(__name__)
 
 # to-do: implement automatically downloading model
-DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models'
+DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models'
+DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-llama-8.sft"
 DEFAULT_MODEL_URLS = {
-    'ar+nar-llama-8/fp32.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft',
+    'ar+nar-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft',
 }
 
 # kludge, probably better to use HF's model downloader function
 # to-do: write to a temp file then copy so downloads can be interrupted
-def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
-    scale = 1
-    if unit == "KiB":
-        scale = (1024)
-    elif unit == "MiB":
-        scale = (1024 * 1024)
-    elif unit == "MiB":
-        scale = (1024 * 1024 * 1024)
-    elif unit == "KB":
-        scale = (1000)
-    elif unit == "MB":
-        scale = (1000 * 1000)
-    elif unit == "MB":
-        scale = (1000 * 1000 * 1000)
+def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):
 
     name = save_path.name
     url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
     if url is None:
@@ -37,19 +26,32 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
     if not save_path.parent.exists():
         save_path.parent.mkdir(parents=True, exist_ok=True)
 
-    r = requests.get(url, stream=True)
-    content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) // scale
+    headers = {}
+    # check if modified
+    if save_path.exists():
+        headers = {"If-Modified-Since": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(save_path.stat().st_mtime))}
+
+    r = requests.get(url, headers=headers, stream=True)
+
+    # not modified
+    if r.status_code == 304:
+        r.close()
+        return
+
+    # to-do: validate lengths match
+
+    content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length'])
     with open(save_path, 'wb') as f:
-        bar = tqdm( unit=unit, total=content_length )
+        bar = tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=content_length, desc=f"Downloading: {name}" )
         for chunk in r.iter_content(chunk_size=chunkSize):
             if not chunk:
                 continue
-            bar.update( len(chunk) / scale )
+            bar.update( len(chunk))
             f.write(chunk)
         bar.close()
 
+    r.close()
+
 
 def get_model(config, training=True, **model_kwargs):
     name = config.name
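Two things change here besides the default save path: downloads now send `If-Modified-Since`, so an up-to-date local file short-circuits on a 304 instead of being re-fetched, and the hand-rolled unit/scale table is dropped in favor of tqdm's built-in byte scaling (the removed table also had duplicated "MiB"/"MB" branches, so its GiB/GB scales were unreachable). A self-contained sketch of the progress-bar half, with stand-in numbers and no network:

    import time
    from tqdm import tqdm

    # tqdm scales bytes to KiB/MiB/GiB on its own, which is what the removed
    # unit/scale table was trying to do by hand
    total = 5 * 1024 * 1024  # pretend Content-Length
    with tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=total, desc="Downloading: example.sft" ) as bar:
        for _ in range(5):
            time.sleep(0.05)           # stand-in for receiving a chunk
            bar.update( 1024 * 1024 )  # report raw byte counts; no manual division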
@@ -21,6 +21,7 @@ from .utils.io import json_read, json_stringify
 from .emb.qnt import decode_to_wave
 from .data import get_lang_symmap, get_random_prompt
 
+
 tts = None
 
 layout = {}
@@ -49,9 +50,9 @@ def gradio_wrapper(inputs):
         return wrapped_function
     return decorated
 
-# returns a list of models, assuming the models are placed under ./training/ or ./models/
-def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
-    yamls = []
+# returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/
+def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
+    configs = []
 
     for path in paths:
         if not path.exists():
@@ -60,10 +61,14 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
         for yaml in path.glob("**/*.yaml"):
             if "/logs/" in str(yaml):
                 continue
-            yamls.append( yaml )
+            configs.append( yaml )
+
+        for sft in path.glob("**/*.sft"):
+            if "/logs/" in str(sft):
+                continue
+            configs.append( sft )
 
-    return yamls
+    return configs
 
 def get_dtypes():
     return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
@@ -73,10 +78,10 @@ def get_attentions():
     return AVAILABLE_ATTENTIONS + ["auto"]
 
 #@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
-def load_model( yaml, device, dtype, attention ):
-    gr.Info(f"Loading: {yaml}")
+def load_model( config, device, dtype, attention ):
+    gr.Info(f"Loading: {config}")
     try:
-        init_tts( yaml=Path(yaml), restart=True, device=device, dtype=dtype, attention=attention )
+        init_tts( config=Path(config), restart=True, device=device, dtype=dtype, attention=attention )
     except Exception as e:
         raise gr.Error(e)
     gr.Info(f"Loaded model")
@@ -107,7 +112,7 @@ def load_sample( speaker ):
 
     return data, (sr, wav)
 
-def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None):
+def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None):
     global tts
 
     if tts is not None:
@@ -118,20 +123,32 @@ def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None):
         tts = None
 
     parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
-    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
+    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
+    parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
     parser.add_argument("--device", type=str, default=device)
     parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default=dtype)
     parser.add_argument("--attention", type=str, default=attention)
     args, unknown = parser.parse_known_args()
 
-    tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
+    if config:
+        if config.suffix == ".yaml" and not args.yaml:
+            args.yaml = config
+        elif config.suffix == ".sft" and not args.model:
+            args.model = config
+
+    if args.yaml:
+        config = args.yaml
+    elif args.model:
+        config = args.model
+
+    tts = TTS( config=config, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
     return tts
 
 @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
 def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
-    if not cfg.yaml_path:
-        raise Exception("No YAML loaded.")
+    if not cfg.models:
+        raise Exception("No model loaded.")
 
     if kwargs.pop("dynamic-sampling", False):
         kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0
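For reference, a sketch of how a raw model file now reaches the web UI plumbing above; the `vall_e.webui` module path is an assumption, and the call is equivalent to launching the UI with `--model` or `VALLE_MODEL=...`:

    from pathlib import Path
    from vall_e.webui import init_tts  # assumed module path for the helper above

    # a .sft suffix is routed through args.model, a .yaml through args.yaml,
    # and whichever ends up set becomes the config handed to TTS()
    tts = init_tts( config=Path("./data/models/ar+nar-llama-8.sft"), device="cuda", dtype="auto" )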
@@ -220,8 +237,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 
 @gradio_wrapper(inputs=layout["inference_stt"]["inputs"].keys())
 def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
-    if not cfg.yaml_path:
-        raise Exception("No YAML loaded.")
+    if not cfg.models:
+        raise Exception("No model loaded.")
 
     if kwargs.pop("dynamic-sampling", False):
         kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
@@ -306,6 +323,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     # setup args
     parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
+    parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
     parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
     parser.add_argument("--share", action="store_true")
     parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
@@ -462,7 +480,7 @@ with ui:
     with gr.Row():
         with gr.Column(scale=7):
             with gr.Row():
-                layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
+                layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
                 layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
                 layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
                 layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")