added an option to load from a model state dict directly instead of a YAML (to-do: do this for LoRAs too); also automatically download the default model if none is provided
parent a96f5aee32
commit ccf71dc1b6
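In short, a config can now come from a model checkpoint as well as from a YAML, and a default model is fetched when neither is given. A minimal usage sketch (the import path and file locations are assumptions based on the modules touched below, not part of this commit):

    # load the global config directly from a checkpoint instead of a YAML
    from pathlib import Path
    from vall_e.config import cfg   # assumed location of the global Config instance

    cfg.load_model( Path("./data/models/ar+nar-llama-8.sft") )   # new: the state dict carries its own model config
    # cfg.load_yaml( Path("./training/config.yaml") )            # unchanged: the YAML route still works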
@@ -21,6 +21,7 @@ from functools import cached_property
from pathlib import Path

from .utils.distributed import world_size
from .utils.io import torch_load
from .utils import set_seed, prune_missing

@dataclass()
@@ -30,7 +31,13 @@ class BaseConfig:
    @property
    def cfg_path(self):
        return Path(self.yaml_path.parent) if self.yaml_path is not None else Path(__file__).parent.parent / "data"
        if self.yaml_path:
            return Path(self.yaml_path.parent)

        if self.model_path:
            return Path(self.model_path.parent)

        return Path(__file__).parent.parent / "data"

    @property
    def rel_path(self):
@@ -93,8 +100,6 @@ class BaseConfig:
    def prune_missing( cls, yaml ):
        default = cls(**{})
        default.format()
        #default = json.loads(default.dumps())

        yaml, missing = prune_missing( source=default, dest=yaml )
        if missing:
            _logger.warning(f'Missing keys in YAML: {missing}')
@@ -108,6 +113,17 @@ class BaseConfig:
        state = cls.prune_missing( state )
        return cls(**state)

    @classmethod
    def from_model( cls, model_path ):
        if not model_path.exists():
            raise Exception(f'Model path does not exist: {model_path}')

        # load state dict and copy its stored model config
        state_dict = torch_load( model_path )["config"]

        state = { "models": [ state_dict ], "trainer": { "load_state_dict": True }, "model_path": model_path }
        return cls(**state)

    @classmethod
    def from_cli(cls, args=sys.argv):
        # legacy support for yaml=`` format
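For reference, from_model() only cares that the loaded checkpoint is a dict with a "config" entry. A rough, self-contained sketch of what it assumes (the paths and key names other than "config" are illustrative, not from this commit):

    from pathlib import Path
    from vall_e.utils.io import torch_load     # same helper the diff imports above

    model_path = Path("./data/models/ar+nar-llama-8.sft")
    state_dict = torch_load( model_path )       # expected to be a dict with a "config" entry
    state = {
        "models": [ state_dict["config"] ],     # the stored model config becomes cfg.models[0]
        "trainer": { "load_state_dict": True }, # so the engine loader reads raw weights, not a trainer checkpoint
        "model_path": model_path,
    }
    # Config(**state) is what from_model() ultimately returns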
@@ -117,8 +133,12 @@ class BaseConfig:
        parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
        parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
        parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
        args, unknown = parser.parse_known_args(args=args)

        if args.model:
            return cls.from_model( args.model )

        if args.yaml:
            return cls.from_yaml( args.yaml )
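Because both flags default to environment variables, setting VALLE_MODEL (or VALLE_YAML) before from_cli() runs is equivalent to passing the flag explicitly, which is what makes the HuggingFace Space case work. A small sketch, assuming the class is importable as vall_e.config.Config:

    import os
    from vall_e.config import Config      # assumed import path

    os.environ["VALLE_MODEL"] = "./data/models/ar+nar-llama-8.sft"
    cfg = Config.from_cli([])              # resolves through from_model() even with no CLI arguments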
@@ -807,10 +827,14 @@ class Config(BaseConfig):
            return diskcache.Cache(self.cache_dir).memoize
        return lambda: lambda x: x

    # I don't remember why this is needed
    # this gets called from vall_e.inference
    def load_yaml( self, config_path ):
        tmp = Config.from_yaml( config_path )
        self.__dict__.update(tmp.__dict__)

    def load_model( self, config_path ):
        tmp = Config.from_model( config_path )
        self.__dict__.update(tmp.__dict__)

    def load_hdf5( self, write=False ):
        if hasattr(self, 'hdf5'):
@@ -870,7 +894,27 @@ class Config(BaseConfig):
        if isinstance(self.optimizations, type):
            self.optimizations = dict()

        self.dataset = Dataset(**self.dataset)
        if isinstance( self.dataset, dict ):
            self.dataset = Dataset(**self.dataset)

        if isinstance( self.hyperparameters, dict ):
            self.hyperparameters = Hyperparameters(**self.hyperparameters)

        if isinstance( self.evaluation, dict ):
            self.evaluation = Evaluation(**self.evaluation)

        if isinstance( self.trainer, dict ):
            self.trainer = Trainer(**self.trainer)

        if isinstance( self.trainer.deepspeed, dict ):
            self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)

        if isinstance( self.inference, dict ):
            self.inference = Inference(**self.inference)

        if isinstance( self.optimizations, dict ):
            self.optimizations = Optimizations(**self.optimizations)

        # convert to expanded paths
        self.dataset.training = [ self.expand(dir) for dir in self.dataset.training ]
        self.dataset.validation = [ self.expand(dir) for dir in self.dataset.validation ]
@@ -906,28 +950,15 @@ class Config(BaseConfig):
                model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")


        self.models = [ Model(**model) for model in self.models ]
        self.loras = [ LoRA(**lora) for lora in self.loras ]
        self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
        self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]

        if not self.models:
            self.models = [ Model() ]

        for model in self.models:
            if not isinstance( model.experimental, dict ):
                continue
            model.experimental = ModelExperimentalSettings(**model.experimental)

        self.hyperparameters = Hyperparameters(**self.hyperparameters)

        self.evaluation = Evaluation(**self.evaluation)

        self.trainer = Trainer(**self.trainer)

        if not isinstance(self.trainer.deepspeed, type):
            self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)

        self.inference = Inference(**self.inference)
        self.optimizations = Optimizations(**self.optimizations)
            if isinstance( model.experimental, dict ):
                model.experimental = ModelExperimentalSettings(**model.experimental)

        if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
            self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
@@ -961,7 +992,7 @@ class Config(BaseConfig):
        try:
            from transformers import PreTrainedTokenizerFast

            tokenizer_path = self.rel_path / self.tokenizer_path if self.yaml_path is not None else None
            tokenizer_path = self.rel_path / self.tokenizer_path
            if tokenizer_path and not tokenizer_path.exists():
                tokenizer_path = Path("./data/") / self.tokenizer_path
@@ -57,6 +57,10 @@ def load_engines(training=True, **model_kwargs):
            tag = open(checkpoint_path).read()
            checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )

        # if loaded using --model=
        if cfg.model_path and cfg.model_path.exists():
            load_path = cfg.model_path

        if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
            _logger.warning(f"Checkpoint missing, but weights found: {load_path}")
            loads_state_dict = True
@@ -19,6 +19,7 @@ from .models import get_models
from .models.lora import enable_lora
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
from .models import download_model, DEFAULT_MODEL_PATH

if deepspeed_available:
    import deepspeed
@@ -34,9 +35,18 @@ class TTS():
        self.loading = False

    def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
        if config:
        if not config:
            download_model()
            config = DEFAULT_MODEL_PATH

        if config.suffix == ".yaml":
            _logger.info(f"Loading YAML: {config}")
            cfg.load_yaml( config )
        elif config.suffix == ".sft":
            _logger.info(f"Loading model: {config}")
            cfg.load_model( config )
        else:
            raise Exception(f"Unknown config passed: {config}")

        try:
            cfg.format( training=False )
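The branches above give TTS three equivalent entry points. A hedged sketch of how they might be exercised (the class location is assumed, and the paths are placeholders):

    from pathlib import Path
    from vall_e.inference import TTS   # assumed module path for the class shown above

    tts = TTS( config=None )                                        # no config: download_model() then DEFAULT_MODEL_PATH
    tts = TTS( config=Path("./training/config.yaml") )              # .yaml suffix: cfg.load_yaml()
    tts = TTS( config=Path("./data/models/ar+nar-llama-8.sft") )    # .sft suffix: cfg.load_model()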
@@ -4,31 +4,20 @@ import requests
from tqdm import tqdm
from pathlib import Path

import time

_logger = logging.getLogger(__name__)

# to-do: implement automatically downloading model
DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models'
DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models'
DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / "ar+nar-llama-8.sft"
DEFAULT_MODEL_URLS = {
    'ar+nar-llama-8/fp32.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft',
    'ar+nar-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft',
}

# kludge, probably better to use HF's model downloader function
# to-do: write to a temp file then copy so downloads can be interrupted
def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
    scale = 1
    if unit == "KiB":
        scale = (1024)
    elif unit == "MiB":
        scale = (1024 * 1024)
    elif unit == "MiB":
        scale = (1024 * 1024 * 1024)
    elif unit == "KB":
        scale = (1000)
    elif unit == "MB":
        scale = (1000 * 1000)
    elif unit == "MB":
        scale = (1000 * 1000 * 1000)

def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):
    name = save_path.name
    url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
    if url is None:
@@ -37,19 +26,32 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
    if not save_path.parent.exists():
        save_path.parent.mkdir(parents=True, exist_ok=True)

    r = requests.get(url, stream=True)
    content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) // scale
    headers = {}
    # check if modified
    if save_path.exists():
        headers = {"If-Modified-Since": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(save_path.stat().st_mtime))}

    r = requests.get(url, headers=headers, stream=True)

    # not modified
    if r.status_code == 304:
        r.close()
        return

    # to-do: validate lengths match

    content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length'])
    with open(save_path, 'wb') as f:
        bar = tqdm( unit=unit, total=content_length )
        bar = tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=content_length, desc=f"Downloading: {name}" )
        for chunk in r.iter_content(chunk_size=chunkSize):
            if not chunk:
                continue

            bar.update( len(chunk) / scale )
            bar.update( len(chunk))
            f.write(chunk)
        bar.close()

    r.close()


def get_model(config, training=True, **model_kwargs):
    name = config.name
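Repeated calls are now cheap: the If-Modified-Since header means an up-to-date local file yields a 304 and the function returns without rewriting anything. A small usage sketch (assuming network access; the import path comes from the inference.py hunk above):

    from vall_e.models import download_model, DEFAULT_MODEL_PATH

    download_model()                       # first run: streams the default fp32.sft into ./data/models/
    download_model( DEFAULT_MODEL_PATH )   # later runs: the server answers 304 Not Modified, so nothing is downloaded again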
@@ -21,6 +21,7 @@ from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave
from .data import get_lang_symmap, get_random_prompt


tts = None

layout = {}
@@ -49,9 +50,9 @@ def gradio_wrapper(inputs):
        return wrapped_function
    return decorated

# returns a list of models, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
    yamls = []
# returns a list of models, assuming the models are placed under ./training/ or ./models/ or ./data/models/
def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data/models/")] ):
    configs = []

    for path in paths:
        if not path.exists():
@@ -60,10 +61,14 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
        for yaml in path.glob("**/*.yaml"):
            if "/logs/" in str(yaml):
                continue
            configs.append( yaml )

        for sft in path.glob("**/*.sft"):
            if "/logs/" in str(sft):
                continue
            configs.append( sft )

            yamls.append( yaml )

    return yamls
    return configs

def get_dtypes():
    return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]
@@ -73,10 +78,10 @@ def get_attentions():
    return AVAILABLE_ATTENTIONS + ["auto"]

#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model( yaml, device, dtype, attention ):
    gr.Info(f"Loading: {yaml}")
def load_model( config, device, dtype, attention ):
    gr.Info(f"Loading: {config}")
    try:
        init_tts( yaml=Path(yaml), restart=True, device=device, dtype=dtype, attention=attention )
        init_tts( config=Path(config), restart=True, device=device, dtype=dtype, attention=attention )
    except Exception as e:
        raise gr.Error(e)
    gr.Info(f"Loaded model")
@@ -107,7 +112,7 @@ def load_sample( speaker ):
    return data, (sr, wav)

def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None):
def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None):
    global tts

    if tts is not None:
@@ -118,20 +123,32 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None):
        tts = None

    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--device", type=str, default=device)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default=dtype)
    parser.add_argument("--attention", type=str, default=attention)
    args, unknown = parser.parse_known_args()

    tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
    if config:
        if config.suffix == ".yaml" and not args.yaml:
            args.yaml = config
        elif config.suffix == ".sft" and not args.model:
            args.model = config

    if args.yaml:
        config = args.yaml
    elif args.model:
        config = args.model

    tts = TTS( config=config, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
    return tts

@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
    if not cfg.yaml_path:
        raise Exception("No YAML loaded.")
    if not cfg.models:
        raise Exception("No model loaded.")

    if kwargs.pop("dynamic-sampling", False):
        kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0
@@ -220,8 +237,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
@gradio_wrapper(inputs=layout["inference_stt"]["inputs"].keys())
def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
    if not cfg.yaml_path:
        raise Exception("No YAML loaded.")
    if not cfg.models:
        raise Exception("No model loaded.")

    if kwargs.pop("dynamic-sampling", False):
        kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
@@ -306,6 +323,7 @@ def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
@@ -462,7 +480,7 @@ with ui:
            with gr.Row():
                with gr.Column(scale=7):
                    with gr.Row():
                        layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
                        layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml or args.model, label="Model")
                        layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
                        layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
                        layout["settings"]["inputs"]["attentions"] = gr.Dropdown(choices=get_attentions(), value="auto", label="Attentions")