diff --git a/vall_e/config.py b/vall_e/config.py
index f81bccc..54c24e6 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -161,6 +161,7 @@ class Model:
 	prom_levels: int = 8
 	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
 	arch_type: str = "transformer"
+	training: bool = True
 
 	@property
 	def scale(self):
@@ -215,10 +216,11 @@ class Model:
 @dataclass()
 class Models:
 	_max_levels: int = 0
+	_prom_levels: int = 1
 
 	_models: list[Model] = field(default_factory=lambda: [
-		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8),
-		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8),
+		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True),
+		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True),
 	])
 
 	def get(self, name=None):
@@ -241,7 +243,7 @@ class Models:
 
 	@property
 	def prom_levels(self):
-		prom_levels = 1
+		prom_levels = self._prom_levels
 		for model in self._models:
 			prom_levels = max(prom_levels, model.prom_levels)
 		return prom_levels
@@ -279,6 +281,8 @@ class Evaluation:
 	ar_temperature: float = 1.0
 	nar_temperature: float = 0.2
 
+	load_disabled_engines: bool = True
+
 @dataclass()
 class DeepSpeed:
 	zero_optimization_level: int = 0
@@ -407,6 +411,8 @@ class Trainer:
 	aggressive_optimizations: bool = False
 	check_for_oom: bool = True
 
+	load_disabled_engines: bool = False
+
 	gc_mode: str | None = None
 
 	weight_dtype: str = "float16"
diff --git a/vall_e/data.py b/vall_e/data.py
index dd10c56..5432662 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -459,7 +459,7 @@ class Dataset(_Dataset):
 
 		return dict(
 			index=index,
-			path=path,
+			path=Path(path),
 			spkr_name=spkr_name,
 			spkr_id=spkr_id,
 			task=task,
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index f0879ec..e7abae7 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -8,4 +8,4 @@ if cfg.trainer.backend == "deepspeed":
 elif cfg.trainer.backend == "local":
 	from .base import Engine
 
-from .base import Engines, TrainFeeder, default_feeder
\ No newline at end of file
+from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
\ No newline at end of file
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 139628b..a697b83 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -74,6 +74,10 @@ class Engine():
 			p.requires_grad_(True)
 		self._frozen_params.clear()
 
+	@property
+	def _training(self):
+		return self._cfg.training
+
 	@property
 	def global_step(self):
 		return self.global_steps
@@ -82,7 +86,8 @@ class Engine():
 	def micro_step(self):
 		return self.micro_steps
 
-	def train_batch_size(self):
+	@property
+	def batch_size(self):
 		return cfg.hyperparameters.batch_size
 
 	def gather_attribute(self, *args, **kwargs):
@@ -137,7 +142,10 @@ class Engine():
 
 	def to(self, *args, **kwargs):
 		self.module = self.module.to(*args, **kwargs)
-		return self.module
+		if self.optimizer:
+			self.optimizer = self.optimizer.to(*args, **kwargs)
+
+		return self
 
 	def __call__(self, *args, **kwargs):
 		return self.forward(*args, **kwargs)
@@ -199,9 +207,7 @@ class Engines(dict[str, Engine]):
 	def setup(self):
 		self._global_step = 0
 		self._micro_step = 0
-
-		for name, engine in self.items():
-			engine.name = name
+		self._batch_size = 0
 
 	@property
 	def global_step(self):
@@ -211,6 +217,10 @@ class Engines(dict[str, Engine]):
 	def micro_step(self):
 		return self._micro_step
 
+	@property
+	def batch_size(self):
+		return self._batch_size
+
 	def gather_attribute(self, *args, **kwargs):
 		ret = {}
 		for engine in self.values():
@@ -242,6 +252,9 @@ class Engines(dict[str, Engine]):
 		cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
 
 		for name, engine in self.items():
+			if not engine._training:
+				continue
+
 			save_dir = cfg.ckpt_dir / name
 			try:
 				engine.save_checkpoint(save_dir, tag=tag)
@@ -282,25 +295,19 @@ class Engines(dict[str, Engine]):
 		if cfg.hyperparameters.scheduler_type == "":
 			self.set_lr(cfg.hyperparameters.learning_rate)
 
-		self._update_global_step()
-		self._update_micro_step()
+		self._update()
 
 	def set_lr(self, lr):
 		for engine in self.values():
+			if not engine._training:
+				continue
 			engine.set_lr(lr)
 
-	def _update_global_step(self):
+	def _update(self):
 		for engine in self.values():
 			self._global_step = max(self._global_step, engine.global_step)
-
-	def _update_micro_step(self):
-		for engine in self.values():
 			self._micro_step = max(self._micro_step, engine.micro_step)
-
-	def train_batch_size(self):
-		batch_size = 0
-		for engine in self.values():
-			batch_size = max(batch_size, engine.train_batch_size())
+			self._batch_size = max(self._batch_size, engine.batch_size)
 
 	def eval(self):
 		for engine in self.values():
@@ -325,8 +332,10 @@ class Engines(dict[str, Engine]):
 
 		if cfg.trainer.gc_mode == 'step':
 			do_gc()
-
 		for name, engine in self.items():
+			if not engine._training:
+				continue
+
 			device = engine.device
 
 			if cfg.trainer.gc_mode == 'substep':
@@ -424,9 +433,9 @@ class Engines(dict[str, Engine]):
 				),
 			)
 
-		self._update_global_step()
-		self._update_micro_step()
-		stats["batch_size"] = self.train_batch_size() # len(batch["text"])
+		self._update()
+
+		stats["batch_size"] = self.batch_size
 		stats["elapsed_time"] = total_elapsed_time
 		stats["wall_time"] = time.time()
 		stats["global_step"] = self.global_step
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 0ca287c..84b5243 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -51,6 +51,10 @@ class Engine(DeepSpeedEngine):
 		for p in self._frozen_params:
 			p.requires_grad_(True)
 		self._frozen_params.clear()
+
+	@property
+	def _training(self):
+		return self._cfg.training
 
 	@property
 	def global_step(self):
@@ -60,6 +64,10 @@ class Engine(DeepSpeedEngine):
 	def micro_step(self):
 		return self.micro_steps
 
+	@property
+	def batch_size(self):
+		return cfg.hyperparameters.batch_size
+
 	def gather_attribute(self, *args, **kwargs):
 		return gather_attribute(self.module, *args, **kwargs)
 
diff --git a/vall_e/train.py b/vall_e/train.py
index b5ff43e..135406e 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -41,7 +41,7 @@ def train_feeder(engine, batch):
 	return loss, stats
 
 @torch.inference_mode()
-def run_eval(engines, eval_name, dl):
+def run_eval(engines, disabled_engines, eval_name, dl):
 	engines_stats = {
 		'eval': eval_name
 	}
@@ -51,11 +51,23 @@ def run_eval(engines, eval_name, dl):
 
 	names = []
 	for name, engine in engines.items():
-		names.append(name)
 		if name[:2] == "ar":
 			AR = engine
 		elif name[:3] == "nar":
 			NAR = engine
+		else:
+			continue
+		names.append(name)
+
+	# hotload the missing models
+	for name, engine in disabled_engines.items():
+		if AR is None and name[:2] == "ar":
+			AR = engine
+		elif NAR is None and name[:3] == "nar":
+			NAR = engine
+		else:
+			continue
+		names.append(name)
 
 	stats = defaultdict(list)
 	stats['loss'] = []
@@ -148,13 +160,18 @@ def main():
 	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
 
 	def eval_fn(engines):
+		disabled_engines = load_engines(invert=True) if cfg.evaluation.load_disabled_engines else {}
 		try:
-			run_eval(engines, "subtrain", subtrain_dl)
-			run_eval(engines, "val", val_dl)
+			run_eval(engines, disabled_engines, "subtrain", subtrain_dl)
+			run_eval(engines, disabled_engines, "val", val_dl)
 		except Exception as e:
 			print("Error occurred while performing eval:", str(e))
 			print(traceback.format_exc())
 
+		if len(disabled_engines.keys()):
+			for name, engine in disabled_engines.items():
+				engine = engine.to("cpu")
+			del disabled_engines
 		qnt.unload_model()
 		do_gc()
 
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index fef6546..b58e2dc 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -28,7 +28,7 @@ from .distributed import (
 	local_leader_only,
 )
 
-from ..engines import Engine, Engines, TrainFeeder, default_feeder
+from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
 from ..models import get_models
 from .utils import to_device, do_gc
 
@@ -36,36 +36,28 @@ from ..utils import wrapper as ml
 from ..data import get_phone_symmap # should decouple from this trainer script
 
 _logger = logging.getLogger(__name__)
-_engines: Engines
 _command: str
 
-def get_global_step():
-	try:
-		return _engines.global_step
-	except:
-		return None
-
-def get_micro_step():
-	try:
-		return _engines.micro_step
-	except:
-		return None
-
-def get_cmd():
-	try:
-		return _command
-	except:
-		raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
-
-
-get_iteration = get_global_step
-
-def load_engines():
+def load_engines(invert=False):
 	models = get_models(cfg.models.get())
 	engines = dict()
 
-	for name in models:
-		model = models[name]
+	for name, model in models.items():
+		# load only the models for training initially
+		# loads disabled models at evaluation time (to load updated weights if training separately)
+		# I'm sure there's a more elegant solution to this
+		if cfg.evaluation.load_disabled_engines:
+			if not invert and not model._cfg.training:
+				continue
+			if invert and model._cfg.training:
+				continue
+		# load only the models for training initially
+		# if load_disabled_engines, then models not marked for training will be loaded but ignored
+		# DeepSpeed has some weird quirks where loading an engine and moving it to CPU will have a memory leak or something
+		# I recommend not using this pathway
+		elif not cfg.trainer.load_disabled_engines:
+			if model._cfg.training:
+				continue
 
 		optimizer = None
 		lr_scheduler = None
@@ -82,7 +74,11 @@ def load_engines():
 				weight_decay=0.01,
 			)
 
-		if cfg.trainer.load_state_dict:
+		if not model._cfg.training:
+			optimizer = None
+			lr_scheduler = None
+
+		if cfg.trainer.load_state_dict or not model._cfg.training:
 			load_path = cfg.ckpt_dir / name / "fp32.pth"
 			state = torch.load(load_path)
 			# exporting the model from the zero_to_fp32.py exports the actual module's dict
@@ -109,18 +105,18 @@ def load_engines():
 
 			# copy weights from the dict into the old portion
 			model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
-			# reuse additional levels, probably bad
-			for n in range(o_resp_tokens, n_resp_tokens):
-				model.resps_emb.weight.data[n] = model.resps_emb.weight.data[o_resp_tokens-1]
 			# copy the full tensors back
 			state['resps_emb.weight'] = model.resps_emb.weight
 
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
-		engines[name] = Engine(
+		# use base engine because DeepSpeed memory leaks
+		engines[name] = (Engine if model._cfg.training else _Engine)(
+		#engines[name] = Engine(
 			model=model,
 			optimizer=optimizer,
 			lr_scheduler=lr_scheduler,
+			_cfg=model._cfg,
 		)
 
 
@@ -130,6 +126,8 @@ def load_engines():
 	if not cfg.trainer.load_state_dict:
 		engines.load_checkpoint()
 
+	do_gc()
+
 	return engines
 
 class EvalFn(Protocol):
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 988f595..86214d4 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -17,12 +17,21 @@ from torch import Tensor, nn
 from tqdm.auto import tqdm
 from typing import Callable, TypeVar, overload
 
+try:
+	from deepspeed.runtime.utils import empty_cache
+except Exception as e:
+	print(str(e))
+	def empty_cache():
+		...
+
 T = TypeVar("T")
 
 def do_gc():
 	gc.collect()
 	torch.cuda.empty_cache()
+	empty_cache()
+
 
 def flatten_dict(d):
 	records = pd.json_normalize(d).to_dict(orient="records")
 	return records[0] if records else {}
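
For reference, a minimal sketch (not part of the patch) of how the eval-time hotloading path added above is meant to be driven, assuming only the signatures shown in this diff; `evaluate` and its dataloader arguments are placeholders for whatever `create_train_val_dataloader()` returns.

from vall_e.config import cfg
from vall_e.train import run_eval
from vall_e.utils.trainer import load_engines
from vall_e.utils.utils import do_gc

def evaluate(engines, subtrain_dl, val_dl):
	# models with training=False are skipped by load_engines() during training and
	# only hotloaded here (with their latest saved weights) when
	# cfg.evaluation.load_disabled_engines is set
	disabled_engines = load_engines(invert=True) if cfg.evaluation.load_disabled_engines else {}

	run_eval(engines, disabled_engines, "subtrain", subtrain_dl)
	run_eval(engines, disabled_engines, "val", val_dl)

	# push the hotloaded engines back off the GPU once eval is done
	for name, engine in disabled_engines.items():
		engine.to("cpu")
	del disabled_engines
	do_gc()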