diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 6ba9cf63..a1c26053 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -123,7 +123,12 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
-            if opt['dist']:
+            has_any_trainable_params = False
+            for p in anet.parameters():
+                if not hasattr(p, 'DO_NOT_TRAIN'):
+                    has_any_trainable_params = True
+                    break
+            if has_any_trainable_params and opt['dist']:
                 if opt['dist_backend'] == 'apex':
                     # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
                     from apex.parallel import DistributedDataParallel
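
Note for reviewers: the new check skips distributed wrapping for any network whose parameters all carry the DO_NOT_TRAIN attribute, so fully frozen networks are never handed to DistributedDataParallel. Below is a minimal sketch of that interaction; the freeze_network helper and the standalone has_any_trainable_params function are hypothetical illustrations of the flag's behavior, not part of this patch.

```python
import torch.nn as nn

def freeze_network(net: nn.Module):
    # Hypothetical helper: tag every parameter so the trainer treats it as frozen.
    for p in net.parameters():
        p.DO_NOT_TRAIN = True

def has_any_trainable_params(net: nn.Module) -> bool:
    # Mirrors the inline check added in this diff: a network is "trainable"
    # if at least one parameter lacks the DO_NOT_TRAIN attribute.
    return any(not hasattr(p, 'DO_NOT_TRAIN') for p in net.parameters())

frozen = nn.Linear(4, 4)
freeze_network(frozen)
trainable = nn.Linear(4, 4)

assert not has_any_trainable_params(frozen)   # would be left unwrapped
assert has_any_trainable_params(trainable)    # would be wrapped when opt['dist'] is set
```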