diff --git a/vall_e/config.py b/vall_e/config.py
index e4df91c..383a93b 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -253,6 +253,7 @@ class Model:
 class Models:
 	_max_levels: int = 0
 	_prom_levels: int = 1
+	_embeddings: str | None = None
 
 	_models: list[Model] = field(default_factory=lambda: [
 		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 0a0eb05..9622abb 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -115,6 +115,8 @@ def load_engines(training=True):
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
+
+		# deepspeed inferencing
 		if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
 			engine_class = _Engine
 
@@ -140,6 +142,33 @@ def load_engines(training=True):
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)
 
+		# copy embeddings if requested
+		if cfg.models._embeddings is not None:
+			embeddings_path = cfg.relpath / cfg.models._embeddings
+
+			if embeddings_path.exists():
+				embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
+				if "module" in embeddings:
+					embeddings = embeddings["module"]
+
+				frozen_params = set()
+
+				for k in list(embeddings.keys()):
+					if re.findall(r'_emb\.', k):
+						frozen_params.add(k)
+					else:
+						del embeddings[k]
+
+				engine.module.load_state_dict(embeddings, strict=False)
+
+				# there's definitely a much better way but I can't be assed at the moment
+				for name, param in engine.module.named_parameters():
+					if name not in frozen_params:
+						continue
+					param.requires_grad_(False)
+					engine._frozen_params.add(param)
+
+
 	#do_gc()
 
 	return engines
\ No newline at end of file
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 0984330..3691eca 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -309,6 +309,7 @@ def example_usage():
 	from ..engines import Engine
 	from tqdm import tqdm
 	from ..utils import wrapper as ml
+	import re
 
 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -367,6 +368,30 @@ def example_usage():
 	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
 	engine = Engine(model=model, optimizer=optimizer)
 
+	# copy embeddings if requested
+	if cfg.models._embeddings is not None:
+		embeddings_path = cfg.relpath / cfg.models._embeddings
+
+		if embeddings_path.exists():
+			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
+			if "module" in embeddings:
+				embeddings = embeddings["module"]
+
+			frozen_params = set()
+			for k in list(embeddings.keys()):
+				if re.findall(r'_emb\.', k):
+					frozen_params.add(k)
+				else:
+					del embeddings[k]
+
+			engine.module.load_state_dict(embeddings, strict=False)
+
+			for name, param in engine.module.named_parameters():
+				if name not in frozen_params:
+					continue
+				param.requires_grad_(False)
+				engine._frozen_params.add(param)
+
 	if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
 		model.model = ml.replace_linear( model.model )
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 26eb385..e008d91 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 import traceback
 import numpy as np
+import re
 
 from typing import Literal, overload
 from functools import partial
@@ -41,13 +42,34 @@ except Exception as e:
 	pass
 
 try:
-	from bitnet import BitNetTransformer
+	from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
+	class BitNetTransformer(nn.Module):
+		def __init__(
+			self,
+			dim: int,
+			depth: int,
+			num_tokens: int,
+			heads=8,
+			ff_mult=4,
+		):
+			super().__init__()
+
+			self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
+			self.norm = BitNetRMSNorm(dim)
+
+		def forward(self, x):
+			x = self.transformer(x)
+			return self.norm( x )
+
+	"""
+	from bitnet import BitNetTransformer
 	def NoEmbedding_BitNetTransformer_Forward(self, x):
 		x = self.transformer(x)
 		return self.to_logits[0](x)
 
 	BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
+	"""
 except Exception as e:
 	print("Error importing `bitnet` arch:", e)