diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 9f3e8cb..8b933f4 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -143,11 +143,11 @@ def load_engines(training=True):
 			del state[k]
 
 		# resize text embedding
-		if model.config.text_tokens != state["text_emb.weight"].shape[0]:
+		if "text_emb.weight" in state and model.config.text_tokens != state["text_emb.weight"].shape[0]:
 			state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
 
 		# resize text embedding
-		if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
+		if "rvq_level_emb.weight" in state and model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
 			state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
 
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 70d0902..5e7e0c1 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -86,7 +86,7 @@ class Engine():
 		self._frozen_params.clear()
 
 	@property
-	def _training(self):
+	def training(self):
 		if not hasattr(self, "hyper_config"):
 			return True
 		return self.hyper_config.training
@@ -308,7 +308,7 @@ class Engines(dict[str, Engine]):
 				"userdata": userdata
 			}
 			if callback:
-				state_dict = callback( state_dict, engine.module )
+				state_dict = callback( state_dict, engine.hyper_config )
 
 			torch.save(state_dict, outpath)
 			print(f"Exported {name} to {outpath}")
@@ -321,7 +321,7 @@ class Engines(dict[str, Engine]):
 		cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
 
 		for name, engine in self.items():
-			if not engine._training:
+			if not engine.training:
 				continue
 
 			save_dir = cfg.ckpt_dir / name
@@ -371,7 +371,7 @@ def set_lr(self, lr):
 		for engine in self.values():
-			if not engine._training:
+			if not engine.training:
 				continue
 			engine.set_lr(lr)
@@ -406,7 +406,7 @@
 			do_gc()
 
 		for name, engine in self.items():
-			if not engine._training:
+			if not engine.training:
 				continue
 
 			device = engine.device
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 08258ae..5ed93e4 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -77,7 +77,7 @@ class Engine(DeepSpeedEngine):
 		self._frozen_params.clear()
 
 	@property
-	def _training(self):
+	def training(self):
 		return self.hyper_config.training
 
 	@property
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index f3b99fc..0915f89 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -269,6 +269,8 @@ class AR_NAR(Base):
 				resps_list=resps_list,
 				lang_list=lang_list,
 				tone_list=tone_list,
+
+				quant_levels=torch.Tensor( [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] ).to( device=device, dtype=torch.int32 ),
 			)
 
 			if recurrent_state is not None:
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 9d4e64b..be32ab8 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -157,6 +157,9 @@ def train(
 	# Training loop
 	for batch in _make_infinite_epochs(train_dl):
+		if not engine.training:
+			continue
+
 		if engines.global_step >= cfg.trainer.iterations:
 			break
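
Below is a minimal, standalone sketch (not part of the patch) of two patterns the diff relies on: the guarded state-dict resize from `vall_e/engines/__init__.py`, and an equivalent, more direct construction of the zero-filled `quant_levels` tensor passed in `vall_e/models/ar_nar.py`. The helper names `maybe_truncate_embedding` and `zero_quant_levels` are hypothetical; only plain PyTorch is assumed.

```python
import torch

def maybe_truncate_embedding(state: dict, key: str, n_tokens: int) -> None:
	# Hypothetical helper mirroring the guarded resize above: only touch the
	# embedding when the checkpoint actually contains the key AND its first
	# dimension disagrees with the current model config. The `key in state`
	# guard is what the patch adds, so checkpoints missing the weight no
	# longer raise a KeyError on load.
	if key in state and n_tokens != state[key].shape[0]:
		state[key] = state[key][:n_tokens]

def zero_quant_levels(batch_size: int, beam_width: int, device) -> torch.Tensor:
	# Equivalent to the diff's
	#   torch.Tensor([ 0 for _ in range(max(batch_size, beam_width)) ])
	#       .to(device=device, dtype=torch.int32)
	# but allocates the int32 zeros directly on the target device instead of
	# building a float32 tensor on the CPU and converting it.
	return torch.zeros(max(batch_size, beam_width), dtype=torch.int32, device=device)
```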