diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index d17efcf4..663d60f9 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -192,7 +192,7 @@ class ExtensibleTrainer(BaseModel): if net_enabled: enabled += 1 for p in net.parameters(): - if p.dtype != torch.int64 and p.dtype != torch.bool: + if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"): p.requires_grad = net_enabled else: p.requires_grad = False