Add tag that can be applied to prevent parameter training

This commit is contained in:
James Betker 2020-10-06 20:39:49 -06:00
parent 2f2e3f33f8
commit 1e415b249b

View File

@ -192,7 +192,7 @@ class ExtensibleTrainer(BaseModel):
                if net_enabled:
                    enabled += 1
                for p in net.parameters():
-                   if p.dtype != torch.int64 and p.dtype != torch.bool:
+                   if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"):
                        p.requires_grad = net_enabled
                    else:
                        p.requires_grad = False