diff --git a/codes/scripts/find_faulty_files.py b/codes/scripts/find_faulty_files.py
index abbc431b..f27472be 100644
--- a/codes/scripts/find_faulty_files.py
+++ b/codes/scripts/find_faulty_files.py
@@ -85,6 +85,6 @@ if __name__ == "__main__":
     for i, data in enumerate(tqdm(dataloader)):
         current_batch = data
         model.feed_data(data, i)
-        model.optimize_parameters(i)
+        model.optimize_parameters(i, optimize=False)
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 1ff67c83..24e5fa72 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -183,7 +183,7 @@ class ExtensibleTrainer(BaseModel):
             if isinstance(v, torch.Tensor):
                 self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
 
-    def optimize_parameters(self, step):
+    def optimize_parameters(self, step, optimize=True):
         # Some models need to make parametric adjustments per-step. Do that here.
         for net in self.networks.values():
             if hasattr(net.module, "update_for_step"):
@@ -255,7 +255,7 @@ class ExtensibleTrainer(BaseModel):
                         raise OverwrittenStateError(k, list(state.keys()))
                     state[k] = v
 
-            if train_step:
+            if train_step and optimize:
                 # And finally perform optimization.
                 [e.before_optimize(state) for e in self.experiments]
                 s.do_step(step)
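
The net effect of the change: `optimize_parameters` still builds state and runs the forward/injector pipeline, but when `optimize=False` the optimization block (`before_optimize` and `s.do_step`) is skipped, so scanning a dataset leaves the network weights untouched. A minimal sketch of how find_faulty_files.py can use the flag; the `scan` helper, the `'path'` key used for reporting, and the pre-built `model`/`dataloader` are assumptions for illustration, not part of the diff:

# Minimal sketch, assuming the surrounding script already constructs `model`
# (an ExtensibleTrainer) and `dataloader` the way find_faulty_files.py does.
from tqdm import tqdm

def scan(model, dataloader):
    for i, data in enumerate(tqdm(dataloader)):
        try:
            model.feed_data(data, i)
            # optimize=False: injectors and losses still run, but
            # s.do_step(step) is skipped, so no weights are updated.
            model.optimize_parameters(i, optimize=False)
        except Exception as e:
            # A batch that crashes the pipeline points at a faulty input file.
            # The 'path' key is an assumption about what the dataset yields.
            print(f"Batch {i} failed ({e}); paths: {data.get('path', '<unknown>')}")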