Added an option to allow injecting embeddings from another model, because it dawned on me how valuable embeddings from a good model can be for subsequent training runs (defined under cfg.models._embeddings as a path relative to the YAML).
commit 7075c2a5f0 · parent 91062361af
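For reference, a minimal sketch of how the option is meant to be pointed at a donor checkpoint, assuming cfg.relpath resolves to the directory the loaded YAML lives in (the message above describes the value as a path relative to the YAML). The YAML key layout, file names, and directories below are hypothetical placeholders, not values from this repo:

# Hypothetical YAML entry under the models block:
#
#   models:
#     _embeddings: "../donor-run/ckpt/fp32.pth"
#
# and how the loader would resolve it, mirroring `cfg.relpath / cfg.models._embeddings`:
from pathlib import Path

yaml_path = Path("./training/config.yaml")        # hypothetical location of the loaded YAML
_embeddings = "../donor-run/ckpt/fp32.pth"        # hypothetical value of models._embeddings

embeddings_path = yaml_path.parent / _embeddings
print(embeddings_path.resolve(), embeddings_path.exists())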
@@ -253,6 +253,7 @@ class Model:
class Models:
	_max_levels: int = 0
	_prom_levels: int = 1
	_embeddings: str | None = None

	_models: list[Model] = field(default_factory=lambda: [
		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
@@ -115,6 +115,8 @@ def load_engines(training=True):
	model.load_state_dict(state, strict=cfg.trainer.strict_loading)

	# deepspeed inferencing
	if backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
		engine_class = _Engine
@@ -140,6 +142,33 @@ def load_engines(training=True):
	for name, engine in engines.items():
		engine.freeze(freeze_all=False)

		# copy embeddings if requested
		if cfg.models._embeddings is not None:
			embeddings_path = cfg.relpath / cfg.models._embeddings

			if embeddings_path.exists():
				embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
				if "module" in embeddings:
					embeddings = embeddings["module"]

				frozen_params = set()

				for k in list(embeddings.keys()):
					if re.findall(r'_emb\.', k):
						frozen_params.add(k)
					else:
						del embeddings[k]

				engine.module.load_state_dict(embeddings, strict=False)

				# there's definitely a much better way, but I can't be assed at the moment
				for name, param in engine.module.named_parameters():
					if name not in frozen_params:
						continue
					param.requires_grad_(False)
					engine._frozen_params.add(param)

	#do_gc()

	return engines
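The block above is the heart of the change: load the donor checkpoint, keep only parameters whose names contain "_emb.", load them non-strictly into each engine's module, then freeze them so training leaves them untouched. A standalone sketch of the same mechanism on a toy module, assuming only that embedding parameters carry "_emb." in their names (as they do for this model); the Toy class and its layer names are invented for illustration:

import re
from torch import nn

class Toy(nn.Module):
	def __init__(self):
		super().__init__()
		self.text_emb = nn.Embedding(256, 16)   # state-dict key "text_emb.weight" matches the r'_emb\.' filter
		self.classifier = nn.Linear(16, 256)    # should NOT be copied from the donor

donor, target = Toy(), Toy()

state = {"module": donor.state_dict()}          # mimic a DeepSpeed-style checkpoint wrapper
if "module" in state:
	state = state["module"]

# keep only the embedding weights, drop everything else
for k in list(state.keys()):
	if not re.findall(r'_emb\.', k):
		del state[k]

target.load_state_dict(state, strict=False)     # classifier stays as initialized

# freeze the injected embeddings so subsequent training can't drift them
for name, param in target.named_parameters():
	if re.findall(r'_emb\.', name):
		param.requires_grad_(False)

print([n for n, p in target.named_parameters() if not p.requires_grad])  # ['text_emb.weight']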
@@ -309,6 +309,7 @@ def example_usage():
	from ..engines import Engine
	from tqdm import tqdm
	from ..utils import wrapper as ml
	import re

	device = "cuda"
	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -367,6 +368,30 @@ def example_usage():
	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
	engine = Engine(model=model, optimizer=optimizer)

	# copy embeddings if requested
	if cfg.models._embeddings is not None:
		embeddings_path = cfg.relpath / cfg.models._embeddings

		if embeddings_path.exists():
			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
			if "module" in embeddings:
				embeddings = embeddings["module"]

			frozen_params = set()
			for k in list(embeddings.keys()):
				if re.findall(r'_emb\.', k):
					frozen_params.add(k)
				else:
					del embeddings[k]

			engine.module.load_state_dict(embeddings, strict=False)

			for name, param in engine.module.named_parameters():
				if name not in frozen_params:
					continue
				param.requires_grad_(False)
				engine._frozen_params.add(param)

	if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
		model.model = ml.replace_linear( model.model )
@@ -3,6 +3,7 @@ import torch
import torch.nn.functional as F
import traceback
import numpy as np
import re

from typing import Literal, overload
from functools import partial
@@ -41,13 +42,34 @@ except Exception as e:
	pass

try:
	from bitnet import BitNetTransformer
	from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm

	class BitNetTransformer(nn.Module):
		def __init__(
			self,
			dim: int,
			depth: int,
			num_tokens: int,
			heads=8,
			ff_mult=4,
		):
			super().__init__()

			self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult )
			self.norm = BitNetRMSNorm(dim)

		def forward(self, x):
			x = self.transformer(x)
			return self.norm( x )

	"""
	from bitnet import BitNetTransformer
	def NoEmbedding_BitNetTransformer_Forward(self, x):
		x = self.transformer(x)
		return self.to_logits[0](x)

	BitNetTransformer.forward = NoEmbedding_BitNetTransformer_Forward
	"""

except Exception as e:
	print("Error importing `bitnet` arch:", e)
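The wrapper above appears to exist because the upstream BitNetTransformer bundles its own token embedding and logit head (the commented-out forward patch points the same way), while this model injects its own embeddings; the wrapper therefore only runs the raw transformer blocks plus a final RMSNorm, and num_tokens is accepted but unused. A hedged usage sketch, assuming the bitnet package is installed, the wrapper class above is in scope, and that the upstream blocks preserve the (batch, length, dim) shape; the sizes are arbitrary:

import torch

# hypothetical sizes; dim must match the width of the embeddings you feed in
model = BitNetTransformer(dim=1024, depth=12, num_tokens=256)

x = torch.randn(1, 64, 1024)   # an already-embedded sequence, not token ids
y = model(x)                   # transformer blocks + RMSNorm
print(y.shape)                 # expected: torch.Size([1, 64, 1024])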