Turn off optimization in find_faulty_files
This commit is contained in:
parent
a66a2bf91b
commit
32cfcf3684
|
@ -85,6 +85,6 @@ if __name__ == "__main__":
|
||||||
for i, data in enumerate(tqdm(dataloader)):
|
for i, data in enumerate(tqdm(dataloader)):
|
||||||
current_batch = data
|
current_batch = data
|
||||||
model.feed_data(data, i)
|
model.feed_data(data, i)
|
||||||
model.optimize_parameters(i)
|
model.optimize_parameters(i, optimize=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -183,7 +183,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
|
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
|
||||||
|
|
||||||
def optimize_parameters(self, step):
|
def optimize_parameters(self, step, optimize=True):
|
||||||
# Some models need to make parametric adjustments per-step. Do that here.
|
# Some models need to make parametric adjustments per-step. Do that here.
|
||||||
for net in self.networks.values():
|
for net in self.networks.values():
|
||||||
if hasattr(net.module, "update_for_step"):
|
if hasattr(net.module, "update_for_step"):
|
||||||
|
@ -255,7 +255,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
raise OverwrittenStateError(k, list(state.keys()))
|
raise OverwrittenStateError(k, list(state.keys()))
|
||||||
state[k] = v
|
state[k] = v
|
||||||
|
|
||||||
if train_step:
|
if train_step and optimize:
|
||||||
# And finally perform optimization.
|
# And finally perform optimization.
|
||||||
[e.before_optimize(state) for e in self.experiments]
|
[e.before_optimize(state) for e in self.experiments]
|
||||||
s.do_step(step)
|
s.do_step(step)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user