added option to specify parameters to freeze per-model in YAML (because I need to see about committing atrocities with converting an AR into an AR+NAR)

mrq 2023-09-07 18:19:51 -05:00
parent c47fc3274e
commit 8837bc34d7
4 changed files with 28 additions and 12 deletions

View File

@@ -164,6 +164,7 @@ class Model:
 	training: bool = True
 	interleave: bool = False
 	use_multiembedding: bool = True # nasty bandaid I got myself into
+	frozen_params: list[str] = field(default_factory=lambda: [])
 
 	@property
 	def full_name(self):
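
For reference, a minimal sketch of how the new field might be fed from a per-model YAML entry. Only `frozen_params` itself comes from this commit; the surrounding `models`/`name` layout and the example parameter names are assumptions for illustration:

# hypothetical sketch: the YAML layout and parameter names are assumed, not taken
# from this commit; `frozen_params` is the new field added to the Model dataclass above
import yaml  # pyyaml

yaml_text = """
models:
  - name: ar                # assumed per-model entry layout
    frozen_params:          # names as reported by module.named_parameters()
      - text_emb.weight
      - proms_emb.weight
"""

model_cfg = yaml.safe_load(yaml_text)["models"][0]
print(model_cfg["frozen_params"])  # ['text_emb.weight', 'proms_emb.weight']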

View File

@@ -64,11 +64,15 @@ class Engine():
 		self.global_samples = 0
 		self.tokens_processed = 0
 
-	def freeze(self):
-		for p in self.module.parameters():
-			if p.requires_grad:
-				p.requires_grad_(False)
-				self._frozen_params.add(p)
+	def freeze(self, freeze_all=True):
+		# set to freeze
+		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
+			raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
+
+		for name, param in self.module.named_parameters():
+			if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
+				param.requires_grad_(False)
+				self._frozen_params.add(param)
 
 	def unfreeze(self):
 		for p in self._frozen_params:
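
To make the new predicate concrete, here is a self-contained sketch of the same name-based freeze run against a toy module; the module layout and the entries in frozen_params are made up for the example:

# toy reproduction of the freeze() predicate above; names are illustrative only
import torch.nn as nn

module = nn.ModuleDict({"ar_head": nn.Linear(4, 4), "nar_head": nn.Linear(4, 4)})
frozen_params = ["ar_head.weight", "ar_head.bias"]  # would come from the per-model YAML
freeze_all = False

frozen = set()
for name, param in module.named_parameters():
	# same predicate as freeze(): everything when freeze_all, else only the listed names
	if (freeze_all and param.requires_grad) or (not freeze_all and name in frozen_params):
		param.requires_grad_(False)
		frozen.add(param)

print([name for name, param in module.named_parameters() if not param.requires_grad])
# ['ar_head.weight', 'ar_head.bias']

With freeze_all=True the predicate reduces to the old behaviour: every parameter that still requires grad gets frozen, regardless of the list.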

View File

@@ -31,6 +31,7 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
 class Engine(DeepSpeedEngine):
 	def __init__(self, *args, **kwargs):
+		self._cfg = None
 		if '_cfg' in kwargs:
 			self._cfg = kwargs['_cfg']
 			kwargs.pop("_cfg")
@@ -43,15 +44,18 @@ class Engine(DeepSpeedEngine):
 		self.tokens_processed = 0
 
-	def freeze(self):
-		for p in self.module.parameters():
-			if p.requires_grad:
-				p.requires_grad_(False)
-				self._frozen_params.add(p)
+	def freeze(self, freeze_all=True):
+		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
+			raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
+
+		for name, param in self.module.named_parameters():
+			if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
+				param.requires_grad_(False)
+				self._frozen_params.add(param)
 
 	def unfreeze(self):
-		for p in self._frozen_params:
-			p.requires_grad_(True)
+		for param in self._frozen_params:
+			param.requires_grad_(True)
 		self._frozen_params.clear()
 
 	@property
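
The self._cfg = None default in __init__ is what lets freeze() above test for a missing config even when the engine was built without the _cfg kwarg. A small placeholder sketch of that pop-before-super pattern, with Parent standing in for DeepSpeedEngine (which presumably would not accept an unknown _cfg keyword):

# placeholder sketch only; Parent stands in for DeepSpeedEngine
class Parent:
	def __init__(self, **kwargs):
		# a parent constructor that would choke on unknown keywords
		assert not kwargs, f"unexpected kwargs: {kwargs}"

class Engine(Parent):
	def __init__(self, *args, **kwargs):
		self._cfg = None                    # safe default so later checks work
		if '_cfg' in kwargs:
			self._cfg = kwargs.pop('_cfg')  # stash it, then hide it from Parent
		super().__init__(*args, **kwargs)

engine = Engine(_cfg={"frozen_params": ["proms_emb.weight"]})
print(engine._cfg)  # {'frozen_params': ['proms_emb.weight']}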

View File

@@ -110,6 +110,8 @@ def load_engines(invert=False):
 		# should decouple the following from this trainer script
 		# probably with passing a fun that defaults to a lambda x: x deal
+		"""
+		# can probably be done a lot more intelligently but oh well
 		# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
 		if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
 			o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
@@ -128,6 +130,7 @@ def load_engines(invert=False):
 			model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
 			# copy the full tensors back
 			state['resps_emb.weight'] = model.resps_emb.weight
+		"""
 
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
@@ -147,6 +150,10 @@ def load_engines(invert=False):
 	if not cfg.trainer.load_state_dict:
 		engines.load_checkpoint()
 
+	# freeze requested params
+	for name, engine in engines.items():
+		engine.freeze(freeze_all=False)
+
 	do_gc()
 
 	return engines
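
Since load_engines() now calls engine.freeze(freeze_all=False) on every engine, the strings under frozen_params have to match what named_parameters() reports exactly; a model that leaves the list at its default empty value simply freezes nothing. A small assumed helper for discovering those names:

# assumed helper for picking frozen_params entries: dump the exact names that
# named_parameters() reports, since those are what freeze() matches against
import torch.nn as nn

def list_param_names(module: nn.Module) -> list[str]:
	return [name for name, _ in module.named_parameters()]

# stand-in module; in practice this would be the loaded model (engine.module)
print(list_param_names(nn.Sequential(nn.Linear(2, 2), nn.ReLU())))
# ['0.weight', '0.bias']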