# Forked from mrq/DL-Art-School (Python, 29 lines, 947 B)


def create_step(opt_step):
    """Placeholder factory for a training step.

    The function body is empty: ``opt_step`` is accepted but never read, and
    the call always returns ``None``.

    Args:
        opt_step: Presumably the options/config describing the step to
            build -- TODO confirm against callers; it is unused here.

    Returns:
        None.
    """
    return None
|
||
|
|
||
|
|
||
|
class base_step:
    """Defines the expected API for a step.

    Every method here is a no-op default: the getters and ``do_step`` return
    ``None`` and ``do_forward_backward`` hands its input state back untouched.
    Concrete steps are expected to override them.
    """

    def get_optimizers(self):
        """Return all optimizers used in this step.

        The base implementation returns ``None``.
        """
        return None

    def get_optimizers_with_default_scheduler(self):
        """Return the optimizers which are opting in for default LR scheduling.

        The base implementation returns ``None``.
        """
        return None

    def get_networks_trained(self):
        """Return the names of the networks this step will train.

        Other networks will be frozen. The base implementation returns
        ``None``.
        """
        return None

    def do_forward_backward(self, state, grad_accum_step):
        """Perform all forward and backward passes for this step.

        All input states are lists or chunked tensors; use ``grad_accum_step``
        to dereference them.

        Args:
            state: The input state for this step.
            grad_accum_step: Index used to select the current chunk of the
                input states. Unused by the base implementation.

        Returns:
            The state, with any variables the step exports (which may be used
            by subsequent steps). The base implementation returns the very
            same ``state`` object, unmodified.
        """
        return state

    def do_step(self):
        """Perform the optimizer step after all gradient accumulation is completed.

        The base implementation does nothing and returns ``None``.
        """
        return None
|