DL-Art-School/codes/models/steps/srgan_generator_step.py
2020-08-18 08:49:32 -06:00

46 lines
1.7 KiB
Python

# Defines the expected API for a step
class SrGanGeneratorStep:
    """Training step for the generator of an SRGAN.

    Keeps the step-level and global options together with the 'base'
    generator and discriminator networks, and builds the generator's pixel
    loss. The optimizer and forward/backward hooks are still placeholders.
    """

    def __init__(self, opt_step, opt, netsG, netsD):
        """Store the configuration and networks, then set up the pixel loss.

        Args:
            opt_step: Options for this step; must contain a 'losses' entry.
            opt: Global options.
            netsG: Mapping of generator networks; the 'base' entry is used.
            netsD: Mapping of discriminator networks; the 'base' entry is used.

        Raises:
            NotImplementedError: If the configured pixel criterion is neither
                'l1' nor 'l2'.
        """
        self.step_opt = opt_step
        self.opt = opt
        self.gen = netsG['base']
        self.disc = netsD['base']
        # Defaults, so these attributes always exist even when the pixel loss
        # is disabled or no losses are configured for this step.
        self.cri_pix = None
        self.l_pix_w = 0
        # The original iterated over the step's losses without ever using the
        # loop variable; the setup is loop-invariant, so run it once when at
        # least one loss is configured.
        if self.step_opt['losses']:
            self._init_pixel_loss()

    def _init_pixel_loss(self):
        """Build the generator pixel criterion and weight from the options."""
        # NOTE(review): the original read an undefined `train_opt`; assumed
        # here to be the 'train' section of the global options - confirm.
        train_opt = self.opt['train']
        if train_opt['pixel_weight'] <= 0:
            import logging
            logging.getLogger(__name__).info('Remove pixel loss.')
            return

        # Imported lazily so the class can be built without torch when the
        # pixel loss is disabled.
        from torch import nn

        l_pix_type = train_opt['pixel_criterion']
        criteria = {'l1': nn.L1Loss, 'l2': nn.MSELoss}
        if l_pix_type not in criteria:
            raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
        # L1Loss/MSELoss hold no parameters or buffers, so no device placement
        # is needed (the original called .to(self.device), but self.device is
        # never assigned).
        self.cri_pix = criteria[l_pix_type]()
        self.l_pix_w = train_opt['pixel_weight']

    def get_optimizers(self):
        """Return all optimizers used in this step.

        Placeholder: not implemented yet, currently returns None.
        """
        pass

    def get_optimizers_with_default_scheduler(self):
        """Return the optimizers which are opting in for default LR scheduling.

        Placeholder: not implemented yet, currently returns None.
        """
        pass

    def get_networks_trained(self):
        """Return the names of the networks this step will train.

        Other networks will be frozen. Placeholder: not implemented yet,
        currently returns None.
        """
        pass

    def do_forward_backward(self, state, grad_accum_step):
        """Perform all forward and backward passes for this step.

        All input states are lists or chunked tensors; use grad_accum_step to
        dereference them. Returns the state with any variables the step
        exports (which may be used by subsequent steps).

        Placeholder: currently returns `state` unchanged.
        """
        return state

    def do_step(self):
        """Perform the optimizer step after gradient accumulation completes.

        Placeholder: not implemented yet.
        """
        pass