added an option to inject embeddings from another model, since it dawned on me how valuable embeddings from a good model can be for subsequent training runs (defined under cfg.models._embeddings as a path relative to the yaml)
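
For reference, a minimal sketch of what the yaml addition might look like; the _embeddings key is what this commit adds, while the surrounding layout and the checkpoint path are illustrative assumptions:

models:
  # hypothetical donor checkpoint; the path is resolved against the yaml's directory (cfg.relpath)
  _embeddings: ./prior-run/fp32.pth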

master
mrq 2024-04-04 19:11:49 +07:00
parent 91062361af
commit 7075c2a5f0
4 changed files with 78 additions and 1 deletion

@@ -253,6 +253,7 @@ class Model:
class Models:
    _max_levels: int = 0
    _prom_levels: int = 1
    _embeddings: str | None = None

    _models: list[Model] = field(default_factory=lambda: [
        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),

@@ -115,6 +115,8 @@ def load_engines(training=True):
            model.load_state_dict(state, strict=cfg.trainer.strict_loading)

        # deepspeed inferencing
        if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
            engine_class = _Engine
@@ -140,6 +142,33 @@ def load_engines(training=True):
    for name, engine in engines.items():
        engine.freeze(freeze_all=False)

        # copy embeddings if requested
        if cfg.models._embeddings is not None:
            embeddings_path = cfg.relpath / cfg.models._embeddings

            if embeddings_path.exists():
                embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
                # checkpoints saved with a "module" wrapper (e.g. deepspeed) nest the weights under it
                if "module" in embeddings:
                    embeddings = embeddings["module"]

                # keep only the embedding weights (keys containing "_emb.") and
                # remember their names so they can be frozen afterwards
                frozen_params = set()
                for k in list(embeddings.keys()):
                    if re.findall(r'_emb\.', k):
                        frozen_params.add(k)
                    else:
                        del embeddings[k]

                engine.module.load_state_dict(embeddings, strict=False)

                # there's definitely a much better way but I can't be assed at the moment
                for name, param in engine.module.named_parameters():
                    if name not in frozen_params:
                        continue
                    param.requires_grad_(False)
                    engine._frozen_params.add(param)

    #do_gc()

    return engines

@@ -309,6 +309,7 @@ def example_usage():
    from ..engines import Engine
    from tqdm import tqdm
    from ..utils import wrapper as ml
    import re

    device = "cuda"
    x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -367,6 +368,30 @@ def example_usage():
    #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
    engine = Engine(model=model, optimizer=optimizer)

    # copy embeddings if requested
    if cfg.models._embeddings is not None:
        embeddings_path = cfg.relpath / cfg.models._embeddings
        if embeddings_path.exists():
            embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
            if "module" in embeddings:
                embeddings = embeddings["module"]

            frozen_params = set()
            for k in list(embeddings.keys()):
                if re.findall(r'_emb\.', k):
                    frozen_params.add(k)
                else:
                    del embeddings[k]

            engine.module.load_state_dict(embeddings, strict=False)

            for name, param in engine.module.named_parameters():
                if name not in frozen_params:
                    continue
                param.requires_grad_(False)
                engine._frozen_params.add(param)

    if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
        model.model = ml.replace_linear( model.model )

@@ -3,6 +3,7 @@ import torch
import torch.nn.functional as F
import traceback
import numpy as np
import re
from typing import Literal, overload
from functools import partial
@@ -41,13 +42,34 @@ except Exception as e:
    pass

try:
    from bitnet import BitNetTransformer
    from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm

    # thin wrapper that keeps bitnet's Transformer block + RMSNorm but drops the
    # package's own token embedding / logits head, since the surrounding model supplies those
    class BitNetTransformer(nn.Module):
        def __init__(
            self,
            dim: int,
            depth: int,
            num_tokens: int,  # accepted for signature parity, unused here
            heads=8,
            ff_mult=4,
        ):
            super().__init__()

            self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
            self.norm = BitNetRMSNorm(dim)

        def forward(self, x):
            x = self.transformer(x)
            return self.norm( x )

    """
    from bitnet import BitNetTransformer

    def NoEmbedding_BitNetTransformer_Forward(self, x):
        x = self.transformer(x)
        return self.to_logits[0](x)

    BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
    """
except Exception as e:
    print("Error importing `bitnet` arch:", e)