diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 56f33603..4ea7edae 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -129,7 +129,7 @@ class BaseModel(): network.load_state_dict(load_net_clean, strict=strict) - def consolidate_state(self, state): + def consolidate_state(self): for o in self.optimizers: if isinstance(o, ZeroRedundancyOptimizer): state['optimizers'].append(o.consolidate_state_dict(to=0))