Support functionality
This commit is contained in: parent 6af5d129ce · commit 6873ad6660
@@ -259,7 +259,7 @@ class EMAWrapper(nn.Module):
             new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer)
             ma_buffer.copy_(new_buffer_value)
 
-    def custom_optimizer_step(self, step):
+    def after_step(self, step):
         if step % self.steps_per_ema == 0:
             self.update_moving_average()
         if step % self.steps_per_reset and step < self.steps_after_no_reset:
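For readers unfamiliar with the EMA machinery this hunk touches: update_average conventionally blends a moving-average tensor toward the live tensor. A minimal sketch of that pattern, assuming the standard new = beta * old + (1 - beta) * current rule (the EMAUpdater name and beta value here are hypothetical stand-ins, not taken from this repo):

import torch

class EMAUpdater:
    # Hypothetical stand-in for self.ema_updater: plain exponential moving average.
    def __init__(self, beta=0.999):
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            # No running average yet: adopt the current value outright.
            return new
        return old * self.beta + (1 - self.beta) * new

updater = EMAUpdater(beta=0.999)
ma_buffer = torch.zeros(4)
current_buffer = torch.ones(4)
# Mirror the hunk above: compute the blended value, then write it back in place.
ma_buffer.copy_(updater.update_average(ma_buffer, current_buffer))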
@@ -360,12 +360,17 @@ class ExtensibleTrainer(BaseModel):
 
     def consume_gradients(self, state, step, it):
         [e.before_optimize(state) for e in self.experiments]
+        # Call into pre-step hooks.
+        for name, net in self.networks.items():
+            if hasattr(net.module, "before_step"):
+                net.module.before_step(it)
+
         step.do_step(it)
 
         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
-            if hasattr(net, "custom_optimizer_step"):
-                net.custom_optimizer_step(it)
+            if hasattr(net.module, "after_step"):
+                net.module.after_step(it)
             if self.do_emas:
                 ema_params = self.emas[name].parameters()
                 net_params = net.parameters()
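The dispatch above is plain hasattr-guarded duck typing: any wrapped network that defines before_step/after_step gets called around the optimizer step, which is how the renamed EMAWrapper.after_step is reached. A self-contained sketch of the pattern (the networks dict is illustrative, and the .module indirection from the trainer's DistributedDataParallel-style wrappers is dropped for brevity):

import torch.nn as nn

class HookedNet(nn.Module):
    # Opts into the trainer's hooks simply by defining the methods.
    def before_step(self, it):
        print(f"before_step at iteration {it}")

    def after_step(self, it):
        print(f"after_step at iteration {it}")

class PlainNet(nn.Module):
    pass  # Defines neither hook; the dispatch below skips it.

networks = {"hooked": HookedNet(), "plain": PlainNet()}
it = 100

for name, net in networks.items():
    if hasattr(net, "before_step"):
        net.before_step(it)

# ... the optimizer step (step.do_step(it)) would run here ...

for name, net in networks.items():
    if hasattr(net, "after_step"):
        net.after_step(it)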
@@ -27,7 +27,6 @@ class ConfigurableStep(Module):
         self.optimizers = None
         self.scaler = GradScaler(enabled=self.opt['fp16'] or opt_get(self.opt, ['grad_scaler_enabled'], False))
         self.grads_generated = False
-        self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else -999999999
         self.clip_grad_eps = opt_get(opt_step, ['clip_grad_eps'], None)
 
         # This is a half-measure that can be used between anomaly_detection and running a potentially problematic
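self.scaler here follows the stock torch.cuda.amp.GradScaler contract, where enabled=False turns every call into a no-op passthrough; that is what the fp16-or-grad_scaler_enabled expression toggles. A minimal sketch of the canonical scale/step/update loop, with a hypothetical model and optimizer:

import torch
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# With enabled=False all scaler calls pass through, so one code path serves
# both fp16 and fp32 training.
scaler = GradScaler(enabled=torch.cuda.is_available())

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()  # scale up so small fp16 grads don't flush to zero
scaler.step(optimizer)         # unscales grads; skips the step if inf/nan appear
scaler.update()                # adapts the scale factor for the next iteration
optimizer.zero_grad()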
@@ -267,22 +266,12 @@ class ConfigurableStep(Module):
         # In some cases, the loss could not be set (e.g. all losses have 'after')
         if train and isinstance(total_loss, torch.Tensor) and total_loss.isfinite():
             loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
-            reset_required = total_loss < self.min_total_loss
-
             # Scale the loss down by the accumulation factor.
             total_loss = total_loss / self.env['mega_batch_factor']
 
             # Get dem grads!
             self.scaler.scale(total_loss).backward()
 
-            if reset_required:
-                # You might be scratching your head at this. Why would you zero grad as opposed to not doing a
-                # backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good
-                # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
-                # implement it at the loss level.
-                self.get_network_for_name(self.step_opt['training']).zero_grad()
-                loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
-
             self.grads_generated = True
             # Reset nan_loss_counter
             self.nan_loss_counter = 0
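The comment in the deleted block records a real DDP constraint worth keeping in mind: backward() is the synchronization point at which DistributedDataParallel all-reduces gradients, so every rank must run it even for a step you intend to discard; the discard is expressed by zeroing gradients afterwards. A minimal single-process sketch of that skip pattern (the threshold and accumulation factor are hypothetical values mirroring min_total_loss and mega_batch_factor):

import torch

model = torch.nn.Linear(8, 1)
min_total_loss = 0.0   # hypothetical threshold, as min_total_loss was configured
mega_batch_factor = 2  # hypothetical gradient-accumulation factor

x, y = torch.randn(4, 8), torch.randn(4, 1)
total_loss = torch.nn.functional.mse_loss(model(x), y)
reset_required = total_loss < min_total_loss

# Scale down so gradients accumulated over mega_batch_factor micro-batches
# average correctly, then always run backward (DDP synchronizes here).
(total_loss / mega_batch_factor).backward()

if reset_required:
    # Discard the step by zeroing grads rather than skipping backward().
    model.zero_grad()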
|