Turn off optimization in find_faulty_files

This commit is contained in:
James Betker 2021-12-09 09:02:09 -07:00
parent a66a2bf91b
commit 32cfcf3684
2 changed files with 3 additions and 3 deletions

View File

@ -85,6 +85,6 @@ if __name__ == "__main__":
for i, data in enumerate(tqdm(dataloader)): for i, data in enumerate(tqdm(dataloader)):
current_batch = data current_batch = data
model.feed_data(data, i) model.feed_data(data, i)
model.optimize_parameters(i) model.optimize_parameters(i, optimize=False)

View File

@ -183,7 +183,7 @@ class ExtensibleTrainer(BaseModel):
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)] self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
def optimize_parameters(self, step): def optimize_parameters(self, step, optimize=True):
# Some models need to make parametric adjustments per-step. Do that here. # Some models need to make parametric adjustments per-step. Do that here.
for net in self.networks.values(): for net in self.networks.values():
if hasattr(net.module, "update_for_step"): if hasattr(net.module, "update_for_step"):
@ -255,7 +255,7 @@ class ExtensibleTrainer(BaseModel):
raise OverwrittenStateError(k, list(state.keys())) raise OverwrittenStateError(k, list(state.keys()))
state[k] = v state[k] = v
if train_step: if train_step and optimize:
# And finally perform optimization. # And finally perform optimization.
[e.before_optimize(state) for e in self.experiments] [e.before_optimize(state) for e in self.experiments]
s.do_step(step) s.do_step(step)