forked from mrq/DL-Art-School
Turn off optimization in find_faulty_files
This commit is contained in:
parent
a66a2bf91b
commit
32cfcf3684
|
@ -85,6 +85,6 @@ if __name__ == "__main__":
|
|||
for i, data in enumerate(tqdm(dataloader)):
|
||||
current_batch = data
|
||||
model.feed_data(data, i)
|
||||
model.optimize_parameters(i)
|
||||
model.optimize_parameters(i, optimize=False)
|
||||
|
||||
|
||||
|
|
|
@ -183,7 +183,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
if isinstance(v, torch.Tensor):
|
||||
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
def optimize_parameters(self, step, optimize=True):
|
||||
# Some models need to make parametric adjustments per-step. Do that here.
|
||||
for net in self.networks.values():
|
||||
if hasattr(net.module, "update_for_step"):
|
||||
|
@ -255,7 +255,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
raise OverwrittenStateError(k, list(state.keys()))
|
||||
state[k] = v
|
||||
|
||||
if train_step:
|
||||
if train_step and optimize:
|
||||
# And finally perform optimization.
|
||||
[e.before_optimize(state) for e in self.experiments]
|
||||
s.do_step(step)
|
||||
|
|
Loading…
Reference in New Issue
Block a user