added option to specify parameters to freeze per-model in YAML (because I need to see about committing atrocities with converting an AR into an AR+NAR)

mrq 2023-09-07 18:19:51 -05:00
parent c47fc3274e
commit 8837bc34d7
4 changed files with 28 additions and 12 deletions
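
For context, a minimal sketch of how the new per-model option might be consumed. The YAML layout, model name, and parameter names below are illustrative assumptions, not the repo's actual config schema; the Model dataclass is reduced to the field this commit adds.

# Illustrative sketch only: shows a per-model `frozen_params` list being read
# from YAML into the Model dataclass as plain parameter names.
from dataclasses import dataclass, field

import yaml  # PyYAML

@dataclass
class Model:
	name: str = "ar"
	frozen_params: list[str] = field(default_factory=lambda: [])

doc = yaml.safe_load("""
models:
  - name: ar
    frozen_params:         # names as reported by named_parameters()
      - proms_emb.weight   # example entries, purely illustrative
      - resps_emb.weight
""")

models = [Model(**entry) for entry in doc["models"]]
print(models[0].frozen_params)  # ['proms_emb.weight', 'resps_emb.weight']

Since the field defaults to an empty list, a model that doesn't set it has nothing frozen by the selective path.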


@@ -164,6 +164,7 @@ class Model:
	training: bool = True
	interleave: bool = False
	use_multiembedding: bool = True # nasty bandaid I got myself into
+	frozen_params: list[str] = field(default_factory=lambda: [])
	@property
	def full_name(self):


@@ -64,11 +64,15 @@ class Engine():
		self.global_samples = 0
		self.tokens_processed = 0
-	def freeze(self):
-		for p in self.module.parameters():
-			if p.requires_grad:
-				p.requires_grad_(False)
-				self._frozen_params.add(p)
+	def freeze(self, freeze_all=True):
+		# set to freeze
+		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
+			raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
+		for name, param in self.module.named_parameters():
+			if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
+				param.requires_grad_(False)
+				self._frozen_params.add(param)
	def unfreeze(self):
		for p in self._frozen_params:
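
As a standalone illustration of the selective path added above (plain torch with a made-up module and config object, not the project's Engine class), freezing by name via named_parameters() behaves like this:

# Standalone sketch of freeze(freeze_all=False): only parameters whose names
# appear in the per-model frozen_params list lose requires_grad.
import torch
from types import SimpleNamespace

module = torch.nn.Linear(4, 4)
cfg = SimpleNamespace(frozen_params=["bias"])  # names as reported by named_parameters()

frozen_params = set()
for name, param in module.named_parameters():
	if name in cfg.frozen_params:
		param.requires_grad_(False)
		frozen_params.add(param)

# only "bias" is frozen; "weight" keeps requires_grad=True
print([name for name, param in module.named_parameters() if not param.requires_grad])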


@@ -31,6 +31,7 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
class Engine(DeepSpeedEngine):
	def __init__(self, *args, **kwargs):
		self._cfg = None
		if '_cfg' in kwargs:
			self._cfg = kwargs['_cfg']
			kwargs.pop("_cfg")
@@ -43,15 +44,18 @@ class Engine(DeepSpeedEngine):
		self.tokens_processed = 0
-	def freeze(self):
-		for p in self.module.parameters():
-			if p.requires_grad:
-				p.requires_grad_(False)
-				self._frozen_params.add(p)
+	def freeze(self, freeze_all=True):
+		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
+			raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
+		for name, param in self.module.named_parameters():
+			if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
+				param.requires_grad_(False)
+				self._frozen_params.add(param)
	def unfreeze(self):
-		for p in self._frozen_params:
-			p.requires_grad_(True)
+		for param in self._frozen_params:
+			param.requires_grad_(True)
		self._frozen_params.clear()
	@property


@@ -110,6 +110,8 @@ def load_engines(invert=False):
			# should decouple the following from this trainer script
			# probably with passing a fun that defaults to a lambda x: x deal
+			"""
			# can probably be done a lot more intelligently but oh well
			# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
			if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
				o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
@@ -128,6 +130,7 @@
				model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
				# copy the full tensors back
				state['resps_emb.weight'] = model.resps_emb.weight
+			"""
			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
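
The block shelved inside the triple-quoted string above grows proms_emb/resps_emb when n_prom_levels or n_prom_tokens change. A generic sketch of the same idea (shapes and tensors are illustrative, not the repo's embedding classes): copy the trained rows of the old checkpoint tensor into a freshly sized one before loading.

# Generic sketch of the shelved resize logic (shapes are made up): copy an old
# checkpoint tensor into a larger, freshly initialized weight so the state dict
# can still be loaded after the prompt levels / token counts grow.
import torch

old = torch.randn(2, 256, 512)        # (o_prom_levels, o_prom_tokens, d_model) from the checkpoint
new = torch.randn(4, 300, 512)        # freshly initialized, larger embedding weight
o_levels, o_tokens, _ = old.shape
new[:o_levels, :o_tokens, :] = old    # preserve the trained rows, leave the rest at init
print(new.shape)                      # torch.Size([4, 300, 512])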
@@ -147,6 +150,10 @@
	if not cfg.trainer.load_state_dict:
		engines.load_checkpoint()
+	# freeze requested params
+	for name, engine in engines.items():
+		engine.freeze(freeze_all=False)
	do_gc()
	return engines
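
Since load_engines now calls engine.freeze(freeze_all=False) on every engine before returning, a quick sanity check of what is still trainable can be handy. A hypothetical helper, not part of the repo:

# Hypothetical helper: after engine.freeze(freeze_all=False), list what got
# frozen and count how many parameters are still trainable.
import torch

def report_frozen(module: torch.nn.Module) -> None:
	frozen = [name for name, param in module.named_parameters() if not param.requires_grad]
	trainable = sum(param.numel() for param in module.parameters() if param.requires_grad)
	print(f"frozen: {frozen} | trainable parameters: {trainable}")

For example, report_frozen(engine.module) right after load_engines() should show only the YAML-listed names as frozen.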