Added an option to allow injecting embeddings from another model, because it dawned on me how valuable embeddings from a good model can be for subsequent training runs (defined under cfg.models._embeddings as a path relative to the YAML).
parent 91062361af
commit 7075c2a5f0
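For reference, a minimal sketch of how the new option gets consumed, going only off the hunks below: the value set under cfg.models._embeddings is joined onto cfg.relpath (presumably the directory the training YAML lives in) to locate the donor checkpoint. The paths and the bare variables standing in for the config fields here are placeholders, not part of this commit.

    # placeholder illustration of the path handling added to load_engines() below;
    # the real code reads both values from the Config object (cfg).
    from pathlib import Path

    relpath = Path("./training/new-model")          # stand-in for cfg.relpath
    _embeddings = "../donor-model/ckpt/fp32.pth"    # stand-in for cfg.models._embeddings
    embeddings_path = relpath / _embeddings         # same join as done in the diff
    print("donor checkpoint found:", embeddings_path.exists())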
@@ -253,6 +253,7 @@ class Model:
 class Models:
     _max_levels: int = 0
     _prom_levels: int = 1
+    _embeddings: str | None = None
 
     _models: list[Model] = field(default_factory=lambda: [
         Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
@@ -115,6 +115,8 @@ def load_engines(training=True):
 
     model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 
 
     # deepspeed inferencing
     if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
         engine_class = _Engine
@@ -140,6 +142,33 @@ def load_engines(training=True):
     for name, engine in engines.items():
         engine.freeze(freeze_all=False)
 
+        # copy embeddings if requested
+        if cfg.models._embeddings is not None:
+            embeddings_path = cfg.relpath / cfg.models._embeddings
+
+            if embeddings_path.exists():
+                embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
+                if "module" in embeddings:
+                    embeddings = embeddings["module"]
+
+                frozen_params = set()
+
+                for k in list(embeddings.keys()):
+                    if re.findall(r'_emb\.', k):
+                        frozen_params.add(k)
+                    else:
+                        del embeddings[k]
+
+                engine.module.load_state_dict(embeddings, strict=False)
+
+                # there's definitely a much better way but I can't be assed at the moment
+                for name, param in engine.module.named_parameters():
+                    if name not in frozen_params:
+                        continue
+                    param.requires_grad_(False)
+                    engine._frozen_params.add(param)
+
+
     #do_gc()
 
     return engines
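Taken out of the diff context, the copy step above boils down to: load a donor checkpoint, keep only the tensors whose names contain `_emb.`, load them non-strictly, and freeze the parameters they landed in. Below is a self-contained sketch of that same technique; the helper name inject_embeddings, treating the target as a plain nn.Module, and returning the frozen names are illustrative choices, not this repo's API.

    import re
    import torch
    from torch import nn

    def inject_embeddings(model: nn.Module, checkpoint_path, device="cpu"):
        """Copy only the '_emb.' tensors from a donor checkpoint into `model`,
        then freeze the parameters that received them (sketch of the logic above)."""
        state = torch.load(checkpoint_path, map_location=torch.device(device))
        if "module" in state:                      # DeepSpeed-style checkpoints nest weights under "module"
            state = state["module"]

        # keep only embedding weights; remember their names for freezing
        keep = {k: v for k, v in state.items() if re.search(r'_emb\.', k)}
        model.load_state_dict(keep, strict=False)  # strict=False leaves every other parameter untouched

        frozen = set()
        for name, param in model.named_parameters():
            if name in keep:
                param.requires_grad_(False)        # donor embeddings stay fixed during the new run
                frozen.add(name)
        return frozen

strict=False is what makes this safe when the donor and the new model only share their embedding modules: every key outside the filtered set is simply ignored.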
@@ -309,6 +309,7 @@ def example_usage():
     from ..engines import Engine
     from tqdm import tqdm
     from ..utils import wrapper as ml
+    import re
 
     device = "cuda"
     x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -367,6 +368,30 @@ def example_usage():
     #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
     engine = Engine(model=model, optimizer=optimizer)
+
+    # copy embeddings if requested
+    if cfg.models._embeddings is not None:
+        embeddings_path = cfg.relpath / cfg.models._embeddings
+
+        if embeddings_path.exists():
+            embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
+            if "module" in embeddings:
+                embeddings = embeddings["module"]
+
+            frozen_params = set()
+            for k in list(embeddings.keys()):
+                if re.findall(r'_emb\.', k):
+                    frozen_params.add(k)
+                else:
+                    del embeddings[k]
+
+            engine.module.load_state_dict(embeddings, strict=False)
+
+            for name, param in engine.module.named_parameters():
+                if name not in frozen_params:
+                    continue
+                param.requires_grad_(False)
+                engine._frozen_params.add(param)
+
     if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
         model.model = ml.replace_linear( model.model )
 
 
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 import traceback
 import numpy as np
+import re
 
 from typing import Literal, overload
 from functools import partial
@@ -41,13 +42,34 @@ except Exception as e:
     pass
 
 try:
-    from bitnet import BitNetTransformer
+    from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
+
+    class BitNetTransformer(nn.Module):
+        def __init__(
+            self,
+            dim: int,
+            depth: int,
+            num_tokens: int,
+            heads=8,
+            ff_mult=4,
+        ):
+            super().__init__()
+
+            self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
+            self.norm = BitNetRMSNorm(dim)
+
+        def forward(self, x):
+            x = self.transformer(x)
+            return self.norm( x )
+
+    """
+    from bitnet import BitNetTransformer
     def NoEmbedding_BitNetTransformer_Forward(self, x):
         x = self.transformer(x)
         return self.to_logits[0](x)
 
     BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
+    """
 
 except Exception as e:
     print("Error importing `bitnet` arch:", e)
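The wrapper above replaces the stock bitnet BitNetTransformer (which, per the commented-out code, bundles its own token embedding and to_logits head) with just the transformer stack plus an RMSNorm, so it can sit behind this project's own embeddings. A usage sketch follows, assuming bitnet.bit_transformer exposes Transformer/RMSNorm with the signatures the import implies and that the block accepts already-embedded (batch, seq, dim) inputs; the dimensions are made up, and num_tokens is accepted only to keep the old constructor signature — the wrapper never uses it.

    import torch

    # illustrative sizes only
    model = BitNetTransformer(dim=1024, depth=12, num_tokens=256, heads=16, ff_mult=4)

    x = torch.randn(1, 75, 1024)   # already-embedded inputs: (batch, seq, dim)
    y = model(x)                   # transformer stack followed by RMSNorm
    print(y.shape)                 # expected: torch.Size([1, 75, 1024])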