diff --git a/vall_e/config.py b/vall_e/config.py
index f7b0153..f7cfc7b 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -202,7 +202,7 @@ class Model:
 	@property
 	# required for fp8 as the lengths needs to be divisible by 8
 	def input_alignment(self):
-		return 8 if cfg.fp8.enabled else 0
+		return 8 if cfg.optimizations.fp8 else 0
 
 	@property
 	def full_name(self):
@@ -220,7 +220,7 @@ class Model:
 		else:
 			name.append(self.arch_type.replace("/", "-"))
 
-		if cfg.bitsandbytes.bitnet:
+		if cfg.optimizations.bitnet:
 			name.append("bitnet")
 
 		if self.interleave:
@@ -521,9 +521,10 @@ class Inference:
 			return torch.float8_e4m3fn
 		return torch.float32
 
+# should be renamed to optimizations
 @dataclass()
-class BitsAndBytes:
-	enabled: bool = False
+class Optimizations:
+	bitsandbytes: bool = False
 	injects: bool = False
 	replace: bool = False
 
@@ -531,11 +532,7 @@ class BitsAndBytes:
 	embedding: bool = True
 
 	bitnet: bool = False
-
-@dataclass()
-class FP8:
-	enabled: bool = False
-	backend: str = "te"
+	fp8: bool = False
 
 @dataclass()
 class Config(_Config):
@@ -550,11 +547,10 @@ class Config(_Config):
 	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
 	trainer: Trainer = field(default_factory=lambda: Trainer)
 	inference: Inference = field(default_factory=lambda: Inference)
-	bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
+	bitsandbytes: dict | list | None = None # deprecated
+	optimizations: Optimizations = field(default_factory=lambda: Optimizations)
 
 	tokenizer: str = "./tokenizer.json"
-
-	fp8: FP8 = field(default_factory=lambda: FP8)
 
 	sample_rate: int = 24_000
 	variable_sample_rate: bool = True
@@ -594,30 +590,31 @@ class Config(_Config):
 			self.dataset.use_hdf5 = False
 
 	def format( self ):
-		#if not isinstance(self.dataset, type):
 		self.dataset = Dataset(**self.dataset)
 		self.dataset.training = [ Path(dir) for dir in self.dataset.training ]
 		self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
 		self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
 
-		#if not isinstance(self.model, type):
 		if self.models is not None:
 			self.model = Model(**next(iter(self.models)))
 		else:
 			self.model = Model(**self.model)
 
-		#if not isinstance(self.hyperparameters, type):
 		self.hyperparameters = Hyperparameters(**self.hyperparameters)
-		#if not isinstance(self.evaluation, type):
+
 		self.evaluation = Evaluation(**self.evaluation)
-		#if not isinstance(self.trainer, type):
+
 		self.trainer = Trainer(**self.trainer)
+
 		if not isinstance(self.trainer.deepspeed, type):
 			self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
-		#if not isinstance(self.inference, type):
+
 		self.inference = Inference(**self.inference)
-		#if not isinstance(self.bitsandbytes, type):
-		self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
+
+		if self.bitsandbytes is not None:
+			self.optimizations = Optimizations(**self.bitsandbytes)
+		else:
+			self.optimizations = Optimizations(**self.optimizations)
 
 # Preserves the old behavior
 class NaiveTokenizer:
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 31e2551..6ee4860 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -44,7 +44,7 @@ def load_engines(training=True):
 		if inferencing:
 			model._cfg.training = False
 
-		if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+		if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
 			model.model = ml.replace_linear( model.model )
 
 		if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index f84f520..99695e9 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -368,7 +368,7 @@ def example_usage():
 		'n_layers': 12, # 32
 		'n_experts': 1,
 
-		'l_padding': 8 if cfg.fp8.enabled else 0,
+		'l_padding': 8 if cfg.optimizations.fp8 else 0,
 
 		'config': cfg.model
 	}
@@ -397,33 +397,7 @@ def example_usage():
 
 	engine = Engine(model=model, optimizer=optimizer)
 
-	# copy embeddings if requested
-	"""
-	if cfg.model._embeddings is not None:
-		embeddings_path = cfg.relpath / cfg.model._embeddings
-
-		if embeddings_path.exists():
-			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-			if "module" in embeddings:
-				embeddings = embeddings["module"]
-
-			frozen_params = set()
-			for k in list(embeddings.keys()):
-				if re.findall(r'_emb.', k):
-					frozen_params.add(k)
-				else:
-					del embeddings[k]
-
-			engine.module.load_state_dict(embeddings, strict=False)
-
-			for name, param in engine.module.named_parameters():
-				if name not in frozen_params:
-					continue
-				param.requires_grad_(False)
-				engine._frozen_params.add(param)
-	"""
-
-	if (cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace) or (cfg.fp8.enabled):
+	if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
 		model.model = ml.replace_linear( model.model )
 
 	torch.save( {
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 1741a8e..9d8af47 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -8,20 +8,20 @@
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear
 # https://github.com/kyegomez/BitNet
-if cfg.bitsandbytes.bitnet:
+if cfg.optimizations.bitnet:
 	from bitnet import BitLinear
 
-if cfg.bitsandbytes.enabled:
+if cfg.optimizations.bitsandbytes:
 	import bitsandbytes as bnb
 
-	if cfg.bitsandbytes.linear:
+	if cfg.optimizations.linear:
 
-		if cfg.bitsandbytes.bitnet:
+		if cfg.optimizations.bitnet:
 			Linear = BitLinear
 		else:
 			Linear = bnb.nn.Linear8bitLt
 
-	if cfg.bitsandbytes.embedding:
+	if cfg.optimizations.embedding:
 		Embedding = bnb.nn.modules.Embedding
 		"""
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
@@ -36,7 +36,7 @@ if cfg.bitsandbytes.enabled:
 
 	"""
 
-if cfg.bitsandbytes.enabled:
+if cfg.optimizations.bitsandbytes:
 	import bitsandbytes as bnb
 
 	Adam = bnb.optim.Adam8bit
@@ -77,7 +77,7 @@ def autocast_forward( func ):
 	return wrapper
 Embedding.forward = autocast_forward(Embedding.forward)
 
-if cfg.fp8.enabled:
+if cfg.optimizations.fp8:
 	import transformer_engine.pytorch as te
 
 	Linear = te.Linear
@@ -90,7 +90,7 @@ else:
 	def autocast():
 		yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
 
-if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
+if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
 	torch.nn.Linear = Linear
 	torch.nn.Embedding = Embedding
 
@@ -98,16 +98,17 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD
 
-# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
 def replace_linear( model ):
-	bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
+	bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
+	klass = Linear
 	device = next(model.parameters()).device
 
 	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
 	for *parent, k in linears:
 		name = '.'.join(parent)
 
+		# copy parameters
 		m = getattr( model.get_submodule(name), k )
@@ -120,7 +121,7 @@ def replace_linear( model ):
 		# overwrite
 		setattr(
 			model.get_submodule(name), k,
-			Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
+			klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
 		)
 
 	return model
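
For anyone updating a config against this change: the old top-level `bitsandbytes:` and `fp8:` blocks collapse into a single `optimizations:` block, and `Config.format()` keeps a back-compat path that feeds a leftover `bitsandbytes:` mapping into the new dataclass. A minimal standalone sketch of that resolution logic, assuming YAML parsing hands `format()` plain dicts (the `resolve` helper and the sample values are illustrative, not part of the patch):

	from dataclasses import dataclass

	@dataclass()
	class Optimizations:
		bitsandbytes: bool = False
		injects: bool = False
		replace: bool = False
		linear: bool = True
		embedding: bool = True
		bitnet: bool = False
		fp8: bool = False

	def resolve(bitsandbytes=None, optimizations=None):
		# mirrors Config.format(): a deprecated `bitsandbytes:` block,
		# if present, wins over the new `optimizations:` block
		if bitsandbytes is not None:
			return Optimizations(**bitsandbytes)
		return Optimizations(**(optimizations or {}))

	print(resolve(bitsandbytes={"injects": True, "bitnet": True}))
	# -> Optimizations(bitsandbytes=False, injects=True, ..., fp8=False)

One caveat the shim does not cover: an old config that still sets `enabled:` under `bitsandbytes:` will raise a `TypeError`, since that flag is renamed to `bitsandbytes` on the new dataclass rather than carried over.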
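The `replace_linear` routine the other hunks call into is a generic module-swap walk: collect the dotted path of every `torch.nn.Linear` in the tree, then rebuild each one in place with whatever constructor the wrapper rebound `Linear` to (`BitLinear`, `bnb.nn.Linear8bitLt`, or `te.Linear`), which is what the new `klass = Linear` capture pins down. A self-contained sketch of the same pattern; the kwargs-copying step is elided between the hunks above, so the `in_features`/`out_features`/`bias` fields here are an assumption about what it copies:

	import torch

	def replace_linear(model, klass=torch.nn.Linear):
		# dotted paths of every plain nn.Linear in the module tree
		linears = [k.split('.') for k, m in model.named_modules()
			if isinstance(m, torch.nn.Linear)]
		for *parent, k in linears:
			name = '.'.join(parent)
			m = getattr(model.get_submodule(name), k)
			# copy construction parameters off the module being replaced
			kwargs = dict(
				in_features=m.in_features,
				out_features=m.out_features,
				bias=m.bias is not None,
			)
			# overwrite the attribute on its parent module
			setattr(model.get_submodule(name), k, klass(**kwargs))
		return model

	net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
	replace_linear(net)  # identity swap; pass e.g. bnb.nn.Linear8bitLt to quantize

Swapping after construction (rather than relying on the global `torch.nn.Linear = Linear` injection path) keeps the replacement opt-in per model, which is why `load_engines` and `example_usage` both gate it behind the new `cfg.optimizations` flags.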