added an option to allow injecting embeddings from another model, because it dawned upon me how valuable embeddings from a good model can be for subsequent training runs (set under cfg.models._embeddings as a path relative to the YAML)
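For reference, a minimal sketch of how the new option is meant to be used, based on the diff below; the donor checkpoint path is a made-up example and the exact YAML layout is an assumption:

# sketch only: the donor path below is illustrative, not from this commit
#
#   models:
#     _embeddings: "./models/donor/ckpt/fp32.pth"   # path relative to the config's directory
#
# at load time this resolves to:
#   embeddings_path = cfg.relpath / cfg.models._embeddings
# and only parameters whose names match r'_emb\.' are copied into the new model
# (load_state_dict with strict=False), then frozen so they are not updated during training.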

This commit is contained in:
mrq 2024-04-04 19:11:49 -05:00
parent 91062361af
commit 7075c2a5f0
4 changed files with 78 additions and 1 deletion

View File

@@ -253,6 +253,7 @@ class Model:
class Models:
	_max_levels: int = 0
	_prom_levels: int = 1
	_embeddings: str | None = None
	_models: list[Model] = field(default_factory=lambda: [
		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),

View File

@@ -115,6 +115,8 @@ def load_engines(training=True):
			model.load_state_dict(state, strict=cfg.trainer.strict_loading)

		# deepspeed inferencing
		if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
			engine_class = _Engine
@@ -140,6 +142,33 @@ def load_engines(training=True):
	for name, engine in engines.items():
		engine.freeze(freeze_all=False)

		# copy embeddings if requested
		if cfg.models._embeddings is not None:
			embeddings_path = cfg.relpath / cfg.models._embeddings

			if embeddings_path.exists():
				embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
				if "module" in embeddings:
					embeddings = embeddings["module"]

				frozen_params = set()

				for k in list(embeddings.keys()):
					if re.findall(r'_emb\.', k):
						frozen_params.add(k)
					else:
						del embeddings[k]

				engine.module.load_state_dict(embeddings, strict=False)

				# there's definitely a much better way but I can't be assed at the moment
				for name, param in engine.module.named_parameters():
					if name not in frozen_params:
						continue
					param.requires_grad_(False)
					engine._frozen_params.add(param)

	#do_gc()

	return engines
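The copy-and-freeze block above is repeated verbatim in example_usage() below, and, as its inline comment admits, it could be factored out. A minimal sketch of such a helper, assuming the same checkpoint layout; the function name and placement are hypothetical and not part of this commit:

import re
import torch

def inject_frozen_embeddings(module, embeddings_path, device="cpu", pattern=r'_emb\.'):
	# hypothetical helper mirroring the inline logic above: copy embedding weights
	# from a donor checkpoint into `module`, freeze them, and return their names
	state = torch.load(embeddings_path, map_location=torch.device(device))
	# deepspeed-style checkpoints nest the actual weights under "module"
	if "module" in state:
		state = state["module"]
	# keep only the embedding weights, drop everything else
	frozen_params = { k for k in state.keys() if re.findall(pattern, k) }
	state = { k: v for k, v in state.items() if k in frozen_params }
	# strict=False because the donor state dict only covers the embeddings
	module.load_state_dict(state, strict=False)
	for name, param in module.named_parameters():
		if name in frozen_params:
			param.requires_grad_(False)
	return frozen_params

The caller would still need to add the corresponding parameters to engine._frozen_params so the engine's own bookkeeping stays in sync with what was frozen.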

View File

@@ -309,6 +309,7 @@ def example_usage():
	from ..engines import Engine
	from tqdm import tqdm
	from ..utils import wrapper as ml
	import re

	device = "cuda"
	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -367,6 +368,30 @@ def example_usage():
	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
	engine = Engine(model=model, optimizer=optimizer)

	# copy embeddings if requested
	if cfg.models._embeddings is not None:
		embeddings_path = cfg.relpath / cfg.models._embeddings

		if embeddings_path.exists():
			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
			if "module" in embeddings:
				embeddings = embeddings["module"]

			frozen_params = set()

			for k in list(embeddings.keys()):
				if re.findall(r'_emb\.', k):
					frozen_params.add(k)
				else:
					del embeddings[k]

			engine.module.load_state_dict(embeddings, strict=False)

			for name, param in engine.module.named_parameters():
				if name not in frozen_params:
					continue
				param.requires_grad_(False)
				engine._frozen_params.add(param)

	if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
		model.model = ml.replace_linear( model.model )
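A quick sanity check for the injection, illustrative only and reusing the `embeddings` and `engine` names from the block above (so it assumes the guard passed and the donor keys match the model's): the copied parameters should equal the donor values and no longer require gradients.

# sketch: verify the donor embeddings were actually copied and frozen
params = dict(engine.module.named_parameters())
for k, v in embeddings.items():
	assert torch.allclose(params[k].detach().float().cpu(), v.float().cpu()), f"{k} was not copied"
	assert not params[k].requires_grad, f"{k} is still trainable"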

View File

@@ -3,6 +3,7 @@ import torch
import torch.nn.functional as F
import traceback
import numpy as np
import re

from typing import Literal, overload
from functools import partial
@@ -41,13 +42,34 @@ except Exception as e:
	pass

try:
	from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm

	class BitNetTransformer(nn.Module):
		def __init__(
			self,
			dim: int,
			depth: int,
			num_tokens: int,
			heads=8,
			ff_mult=4,
		):
			super().__init__()

			self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
			self.norm = BitNetRMSNorm(dim)

		def forward(self, x):
			x = self.transformer(x)
			return self.norm( x )

	"""
	from bitnet import BitNetTransformer
	def NoEmbedding_BitNetTransformer_Forward(self, x):
		x = self.transformer(x)
		return self.to_logits[0](x)

	BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
	"""
except Exception as e:
	print("Error importing `bitnet` arch:", e)