forked from mrq/DL-Art-School
46 lines
1.7 KiB
Python
46 lines
1.7 KiB
Python
|
# Defines the expected API for a step
|
||
|
class SrGanGeneratorStep:
    """Step object holding the 'base' generator/discriminator and a pixel loss.

    The constructor stores the step options, the global options and the
    ``'base'`` networks, then configures the pixel criterion. The remaining
    methods are placeholders for the step API and do not do any work yet.
    """

    def __init__(self, opt_step, opt, netsG, netsD):
        """Store options and networks and configure the pixel criterion.

        Args:
            opt_step: Options for this step; must contain a ``'losses'``
                iterable.
            opt: Global options. NOTE(review): the pixel-loss settings
                (``'pixel_weight'``, ``'pixel_criterion'``) are read from
                ``opt['train']`` - the original code used an undefined
                ``train_opt`` name, so confirm this is the intended source.
            netsG: Mapping of generator networks; ``netsG['base']`` is used.
            netsD: Mapping of discriminator networks; ``netsD['base']`` is
                used.

        Raises:
            NotImplementedError: If ``'pixel_criterion'`` is neither ``'l1'``
                nor ``'l2'`` while ``'pixel_weight'`` is positive.
        """
        # Imported here because this file has no module-level import section.
        import logging
        from torch import nn

        self.step_opt = opt_step
        self.opt = opt
        self.gen = netsG['base']
        self.disc = netsD['base']

        # The original loop over the configured losses had no body, which made
        # the whole file unparseable. Per-loss setup is not implemented yet.
        # TODO(review): build a criterion for each entry of opt_step['losses'].
        for _loss in self.step_opt['losses']:
            pass

        train_opt = self.opt['train']

        # G pixel loss
        if train_opt['pixel_weight'] > 0:
            l_pix_type = train_opt['pixel_criterion']
            # L1Loss / MSELoss own no parameters or buffers, so they need no
            # device placement; the original `.to(self.device)` referenced an
            # attribute that was never assigned.
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss()
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = train_opt['pixel_weight']
        else:
            logging.getLogger(__name__).info('Remove pixel loss.')
            self.cri_pix = None

    def get_optimizers(self):
        """Return all optimizers used in this step.

        Not implemented yet: currently returns ``None``.
        """
        pass

    def get_optimizers_with_default_scheduler(self):
        """Return optimizers which are opting in for default LR scheduling.

        Not implemented yet: currently returns ``None``.
        """
        pass

    def get_networks_trained(self):
        """Return the names of the networks this step will train.

        Other networks will be frozen. Not implemented yet: currently returns
        ``None``.
        """
        pass

    def do_forward_backward(self, state, grad_accum_step):
        """Perform all forward and backward passes for this step.

        All input states are lists or chunked tensors; ``grad_accum_step`` is
        meant to be used to dereference them. The state is returned together
        with any variables the step exports (which may be used by subsequent
        steps).

        Not implemented yet: currently returns ``state`` unchanged.
        """
        return state

    def do_step(self):
        """Perform the optimizer step after gradient accumulation completes.

        Not implemented yet: currently does nothing.
        """
        pass
|