added option to specify parameters to freeze per-model in YAML (because I need to see about committing atrocities with converting an AR into an AR+NAR)
parent c47fc3274e
commit 8837bc34d7
@@ -164,6 +164,7 @@ class Model:
     training: bool = True
     interleave: bool = False
     use_multiembedding: bool = True # nasty bandaid I got myself into
+    frozen_params: list[str] = field(default_factory=lambda: [])
 
     @property
     def full_name(self):
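To make the new knob concrete, here is a minimal sketch of how the field slots into the per-model config; only the fields visible in the hunk above are reproduced, the per-model YAML entry is assumed to map straight onto these dataclass fields, and the parameter names in the example list are placeholders rather than the repo's actual module names:

from dataclasses import dataclass, field

@dataclass
class Model:
    # only the fields shown in the diff above are reproduced here
    training: bool = True
    interleave: bool = False
    use_multiembedding: bool = True
    # new: parameter names (as yielded by named_parameters()) to freeze
    # when the engine is told freeze(freeze_all=False)
    frozen_params: list[str] = field(default_factory=lambda: [])

# hypothetical per-model entry: keep some pretrained AR weights fixed while
# the rest trains; the names below are illustrative, not the real ones
ar_nar = Model(frozen_params=["proms_emb.weight", "resps_emb.weight"])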
@@ -64,11 +64,15 @@ class Engine():
         self.global_samples = 0
         self.tokens_processed = 0
 
-    def freeze(self):
-        for p in self.module.parameters():
-            if p.requires_grad:
-                p.requires_grad_(False)
-                self._frozen_params.add(p)
+    def freeze(self, freeze_all=True):
+        # set to freeze
+        if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
+            raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
+
+        for name, param in self.module.named_parameters():
+            if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
+                param.requires_grad_(False)
+                self._frozen_params.add(param)
 
     def unfreeze(self):
        for p in self._frozen_params:
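A self-contained sketch of the freeze_all=False path introduced above, with a toy module standing in for self.module and a bare namespace standing in for the model config (both assumptions, purely for illustration):

from types import SimpleNamespace
from torch import nn

module = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))     # stand-in for self.module
cfg = SimpleNamespace(frozen_params=["0.weight", "0.bias"])  # names as named_parameters() reports them

frozen_params = set()
for name, param in module.named_parameters():
    # freeze_all=False: only parameters whose names appear in the config are frozen
    if name in cfg.frozen_params:
        param.requires_grad_(False)
        frozen_params.add(param)

print([name for name, p in module.named_parameters() if not p.requires_grad])
# ['0.weight', '0.bias'] — the second Linear keeps training

Note that, as written in the diff, the guard at the top of freeze() raises whenever no _cfg with a frozen_params attribute is attached, even on the freeze_all=True path.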
@@ -31,6 +31,7 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
 
 class Engine(DeepSpeedEngine):
     def __init__(self, *args, **kwargs):
+        self._cfg = None
         if '_cfg' in kwargs:
             self._cfg = kwargs['_cfg']
             kwargs.pop("_cfg")
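The constructor change above follows the usual pattern of stripping a custom keyword out of **kwargs before delegating to a third-party base class that would reject it; a minimal runnable sketch, with a dummy base class in place of DeepSpeedEngine purely for illustration:

class DummyBase:                      # stands in for DeepSpeedEngine
    def __init__(self, **kwargs):
        self.kwargs = kwargs

class Engine(DummyBase):
    def __init__(self, *args, **kwargs):
        self._cfg = None              # default so later code can safely test "is None"
        if '_cfg' in kwargs:
            self._cfg = kwargs.pop('_cfg')   # same effect as the two lines in the diff
        super().__init__(*args, **kwargs)

engine = Engine(_cfg={"frozen_params": []}, some_base_kwarg=1)
assert engine._cfg is not None and '_cfg' not in engine.kwargs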
@@ -43,15 +44,18 @@ class Engine(DeepSpeedEngine):
 
         self.tokens_processed = 0
 
-    def freeze(self):
-        for p in self.module.parameters():
-            if p.requires_grad:
-                p.requires_grad_(False)
-                self._frozen_params.add(p)
+    def freeze(self, freeze_all=True):
+        if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
+            raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
+
+        for name, param in self.module.named_parameters():
+            if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
+                param.requires_grad_(False)
+                self._frozen_params.add(param)
 
     def unfreeze(self):
-        for p in self._frozen_params:
-            p.requires_grad_(True)
+        for param in self._frozen_params:
+            param.requires_grad_(True)
         self._frozen_params.clear()
 
     @property
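The _frozen_params set matters for the round trip: unfreeze() only re-enables what freeze() itself disabled, so parameters that were already non-trainable for other reasons stay frozen. A toy demonstration, using nn.Linear as a stand-in module:

from torch import nn

module = nn.Linear(4, 4)
module.bias.requires_grad_(False)     # frozen for some unrelated reason

_frozen_params = set()
for p in module.parameters():         # freeze_all=True path
    if p.requires_grad:
        p.requires_grad_(False)
        _frozen_params.add(p)

for param in _frozen_params:          # unfreeze()
    param.requires_grad_(True)
_frozen_params.clear()

print(module.weight.requires_grad, module.bias.requires_grad)   # True False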
@@ -110,6 +110,8 @@ def load_engines(invert=False):
     # should decouple the following from this trainer script
     # probably with passing a fun that defaults to a lambda x: x deal
+
+    """
     # can probably be done a lot more intelligently but oh well
     # extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
     if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
         o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
@@ -128,6 +130,7 @@ def load_engines(invert=False):
         model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
         # copy the full tensors back
         state['resps_emb.weight'] = model.resps_emb.weight
+    """
 
     model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
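For context, the block that gets commented out above was padding checkpoint embeddings up to the current model's shape so strict loading would not complain: it copies the old (smaller) tensor into the corner of the new one and puts the result back into the state dict. A standalone sketch with made-up shapes:

import torch

old = torch.randn(2, 1024, 512)       # (o_prom_levels, o_prom_tokens, d_model) from the checkpoint
new = torch.randn(4, 1280, 512)       # freshly initialised, larger proms_emb.weight

o_prom_levels, o_prom_tokens, d_model = old.shape
new[:o_prom_levels, :o_prom_tokens, :] = old      # copy the old weights into the corner
state = {'proms_emb.weight': new}                 # overwrite the checkpoint entry before load_state_dict()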
@@ -147,6 +150,10 @@ def load_engines(invert=False):
     if not cfg.trainer.load_state_dict:
         engines.load_checkpoint()
 
+    # freeze requested params
+    for name, engine in engines.items():
+        engine.freeze(freeze_all=False)
+
     do_gc()
 
     return engines
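Since frozen_params defaults to an empty list, calling engine.freeze(freeze_all=False) unconditionally for every engine is harmless for models that do not request any freezing; a quick self-contained check of that assumption:

from types import SimpleNamespace
from torch import nn

module = nn.Linear(4, 4)
cfg = SimpleNamespace(frozen_params=[])    # the dataclass default

for name, param in module.named_parameters():
    if name in cfg.frozen_params:          # never true, so nothing is frozen
        param.requires_grad_(False)

assert all(p.requires_grad for p in module.parameters())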