From 32cfcf36842dbdc34296e37287f26703a4c1b905 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 9 Dec 2021 09:02:09 -0700 Subject: [PATCH] Turn off optimization in find_faulty_files --- codes/scripts/find_faulty_files.py | 2 +- codes/trainer/ExtensibleTrainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codes/scripts/find_faulty_files.py b/codes/scripts/find_faulty_files.py index abbc431b..f27472be 100644 --- a/codes/scripts/find_faulty_files.py +++ b/codes/scripts/find_faulty_files.py @@ -85,6 +85,6 @@ if __name__ == "__main__": for i, data in enumerate(tqdm(dataloader)): current_batch = data model.feed_data(data, i) - model.optimize_parameters(i) + model.optimize_parameters(i, optimize=False) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 1ff67c83..24e5fa72 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -183,7 +183,7 @@ class ExtensibleTrainer(BaseModel): if isinstance(v, torch.Tensor): self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)] - def optimize_parameters(self, step): + def optimize_parameters(self, step, optimize=True): # Some models need to make parametric adjustments per-step. Do that here. for net in self.networks.values(): if hasattr(net.module, "update_for_step"): @@ -255,7 +255,7 @@ class ExtensibleTrainer(BaseModel): raise OverwrittenStateError(k, list(state.keys())) state[k] = v - if train_step: + if train_step and optimize: # And finally perform optimization. [e.before_optimize(state) for e in self.experiments] s.do_step(step)