From 34001ad76598146ddc36432257246bcda671d312 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 18 Feb 2022 18:52:33 -0700
Subject: [PATCH] et

---
 codes/trainer/ExtensibleTrainer.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 6ba9cf63..a1c26053 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -123,7 +123,12 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
-            if opt['dist']:
+            has_any_trainable_params = False
+            for p in anet.parameters():
+                if not hasattr(p, 'DO_NOT_TRAIN'):
+                    has_any_trainable_params = True
+                    break
+            if has_any_trainable_params and opt['dist']:
                 if opt['dist_backend'] == 'apex':
                     # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
                     from apex.parallel import DistributedDataParallel
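
The change above gates DistributedDataParallel wrapping on whether a network actually has anything to train: a parameter is treated as frozen when it carries a DO_NOT_TRAIN attribute, and a network whose parameters are all flagged that way is left unwrapped. Below is a minimal sketch of the same idea outside the trainer, assuming torch.distributed has already been initialized; freeze_module and maybe_wrap_ddp are illustrative helper names, not functions from this repository.

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel


    def freeze_module(module: nn.Module) -> None:
        # Mark every parameter as frozen using the same attribute convention
        # the trainer checks for (an assumed usage pattern, for illustration).
        for p in module.parameters():
            p.DO_NOT_TRAIN = True


    def has_any_trainable_params(module: nn.Module) -> bool:
        # Mirrors the loop added by this patch: succeed on the first parameter
        # that is not flagged DO_NOT_TRAIN.
        return any(not hasattr(p, 'DO_NOT_TRAIN') for p in module.parameters())


    def maybe_wrap_ddp(module: nn.Module, use_dist: bool) -> nn.Module:
        # Only wrap when distributed training is enabled and the module has
        # something to synchronize; fully frozen networks are returned as-is.
        # Assumes torch.distributed.init_process_group() has already been called.
        if use_dist and has_any_trainable_params(module):
            return DistributedDataParallel(module)
        return module

DDP registers gradient-allreduce hooks for every parameter it wraps, so skipping networks with nothing to train avoids that overhead for models that are loaded into the trainer but never optimized.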