diff --git a/vall_e/config.py b/vall_e/config.py index 60e05d4..2e6204e 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -460,6 +460,7 @@ class Trainer: @dataclass() class Inference: + backend: str = "local" weight_dtype: str = "float32" amp: bool = False @@ -492,6 +493,7 @@ class BitsAndBytes: class Config(_Config): device: str = "cuda" mode: str = "training" # "inferencing" + experimental: bool = False # So I can stop commenting out things when committing dataset: Dataset = field(default_factory=lambda: Dataset) models: Models = field(default_factory=lambda: Models) diff --git a/vall_e/data.py b/vall_e/data.py index 7876ba5..0772c13 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -291,8 +291,10 @@ class Dataset(_Dataset): # shuffle it up a bit prom_length = 0 - #trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds] - trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75) + if cfg.experimental: + trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds] + else: + trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75) for _ in range(cfg.dataset.max_prompts): path = random.choice(choices) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index e7abae7..30a6de8 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -8,4 +8,112 @@ if cfg.trainer.backend == "deepspeed": elif cfg.trainer.backend == "local": from .base import Engine -from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine \ No newline at end of file +from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine + +from ..models import get_models +from ..utils import wrapper as ml +import torch + +deepspeed_available = False +try: + import deepspeed + deepspeed_available = True +except Exception as e: + pass + +def load_engines(): + models = get_models(cfg.models.get()) + engines = dict() + + for name, model in models.items(): + optimizer = None + lr_scheduler = None + + inferencing = cfg.mode == "inferencing" or not model._cfg.training + backend = cfg.inference.backend if inferencing else cfg.trainer.backend + dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype + amp = cfg.inference.amp if inferencing else cfg.trainer.amp + loads_state_dict = cfg.trainer.load_state_dict or inferencing + + engine_class = _Engine if backend == "local" or inferencing else Engine + + if inferencing: + model._cfg.training = False + + if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer): + optimizer_class = None + params = { + "lr": cfg.hyperparameters.learning_rate, + } + if cfg.hyperparameters.optimizer.lower() == "adamw": + params["betas"] = (0.9, 0.96) + params["eps"] = 1e-07 + params["weight_decay"] = 0.01 + + optimizer_class = ml.AdamW + elif cfg.hyperparameters.optimizer.lower() == "sgd": + optimizer = ml.SGD + elif cfg.hyperparameters.optimizer.lower() == "prodigy": + optimizer_class = ml.Prodigy + else: + raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}') + + params.update(cfg.hyperparameters.optimizer_params) + optimizer = optimizer_class( + [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], + **params, + ) + + # set up our LR scheduler here + + if inferencing: + optimizer = None + lr_scheduler = None + + # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present + if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists(): + print("DeepSpeed checkpoint missing, but weights found.") + loads_state_dict = True + + stats = None + if loads_state_dict: + load_path = cfg.ckpt_dir / name / "fp32.pth" + state = torch.load(load_path, map_location=torch.device(cfg.device)) + + # state dict is not just the module, extract the extra trainer details + if "stats" in state: + stats = state["stats"] + + if "module" in state: + state = state["module"] + + model.load_state_dict(state, strict=cfg.trainer.strict_loading) + + # deepspeed inferencing + if backend == "local" and inferencing and deepspeed_available: #and sys.platform.startswith("win"): + engine_class = _Engine + model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module + + # use base engine if requested + engines[name] = engine_class( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + + _cfg=model._cfg, + stats=stats + ) + + engines = Engines(engines) + engines.setup() + + if not cfg.trainer.load_state_dict: + engines.load_checkpoint() + + # freeze requested params + for name, engine in engines.items(): + engine.freeze(freeze_all=False) + + #do_gc() + + return engines \ No newline at end of file diff --git a/vall_e/export.py b/vall_e/export.py index f7778c6..982d870 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -3,7 +3,7 @@ import argparse import torch from .data import get_phone_symmap -from .train import load_engines +from .engines import load_engines from .config import cfg def main(): diff --git a/vall_e/inference.py b/vall_e/inference.py index 594a082..34df7e5 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -12,18 +12,11 @@ from .utils import to_device from .config import cfg from .models import get_models -from .train import load_engines +from .engines import load_engines, deepspeed_available from .data import get_phone_symmap, _load_quants -use_deepspeed_inference = False -# to-do: integrate this for windows -""" -try: +if deepspeed_available: import deepspeed - use_deepspeed_inference = True -except Exception as e: - pass -""" class TTS(): def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ): @@ -48,9 +41,9 @@ class TTS(): if device is None: device = cfg.device - cfg.mode = "inferencing" cfg.device = device - cfg.trainer.load_state_dict = True + cfg.mode = "inferencing" + cfg.trainer.backend = cfg.inference.backend cfg.trainer.weight_dtype = dtype cfg.inference.weight_dtype = dtype @@ -70,6 +63,10 @@ class TTS(): state = state['module'] model.load_state_dict(state) + + if deepspeed_available: + model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module + return model if ar_ckpt and nar_ckpt: @@ -94,12 +91,8 @@ class TTS(): self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) - if use_deepspeed_inference: - self.ar = deepspeed.init_inference(model=self.ar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval() - self.nar = deepspeed.init_inference(model=self.nar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval() - else: - self.ar.eval() - self.nar.eval() + self.ar.eval() + self.nar.eval() if self.symmap is None: self.symmap = get_phone_symmap() diff --git a/vall_e/train.py b/vall_e/train.py index 9179553..91e63cd 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -5,7 +5,6 @@ from .data import create_train_val_dataloader from .emb import qnt from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc -from .utils.trainer import load_engines import auraloss import json diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 830e8a0..2214f9b 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -28,8 +28,7 @@ from .distributed import ( local_leader_only, ) -from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder -from ..models import get_models +from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines from .utils import to_device, do_gc from ..utils import wrapper as ml @@ -38,86 +37,6 @@ from ..data import get_phone_symmap # should decouple from this trainer script _logger = logging.getLogger(__name__) _command: str -def load_engines(): - models = get_models(cfg.models.get()) - engines = dict() - - for name, model in models.items(): - optimizer = None - lr_scheduler = None - - if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer): - optimizer_class = None - params = { - "lr": cfg.hyperparameters.learning_rate, - } - if cfg.hyperparameters.optimizer.lower() == "adamw": - params["betas"] = (0.9, 0.96) - params["eps"] = 1e-07 - params["weight_decay"] = 0.01 - - optimizer_class = ml.AdamW - elif cfg.hyperparameters.optimizer.lower() == "sgd": - optimizer = ml.SGD - elif cfg.hyperparameters.optimizer.lower() == "prodigy": - optimizer_class = ml.Prodigy - else: - raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}') - - params.update(cfg.hyperparameters.optimizer_params) - optimizer = optimizer_class( - [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], - **params, - ) - - # set up our LR scheduler here - - if not model._cfg.training: - optimizer = None - lr_scheduler = None - - # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present - if not cfg.trainer.load_state_dict and cfg.trainer.backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists(): - print("DeepSpeed checkpoint missing, but weights found.") - cfg.trainer.load_state_dict = True - - stats = None - if cfg.trainer.load_state_dict or not model._cfg.training: - load_path = cfg.ckpt_dir / name / "fp32.pth" - state = torch.load(load_path, map_location=torch.device(cfg.device)) - - # state dict is not just the module, extract the extra trainer details - if "stats" in state: - stats = state["stats"] - - if "module" in state: - state = state["module"] - - model.load_state_dict(state, strict=cfg.trainer.strict_loading) - - # use base engine because DeepSpeed memory leaks if it's a non-training model - engines[name] = (Engine if model._cfg.training else _Engine)( - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - - _cfg=model._cfg, - stats=stats - ) - - engines = Engines(engines) - engines.setup() - - if not cfg.trainer.load_state_dict: - engines.load_checkpoint() - - # freeze requested params - for name, engine in engines.items(): - engine.freeze(freeze_all=False) - - do_gc() - - return engines class EvalFn(Protocol): def __call__(self, *, engines: Engines):