From 1e415b249b7ba73a59e2dabe24189ff8127052fd Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 6 Oct 2020 20:39:49 -0600 Subject: [PATCH] Add DO_NOT_TRAIN tag that can be applied to a parameter to prevent it from being trained --- codes/models/ExtensibleTrainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index d17efcf4..663d60f9 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -192,7 +192,7 @@ class ExtensibleTrainer(BaseModel): if net_enabled: enabled += 1 for p in net.parameters(): - if p.dtype != torch.int64 and p.dtype != torch.bool: + if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"): p.requires_grad = net_enabled else: p.requires_grad = False