more debugging

This commit is contained in:
James Betker 2020-08-25 18:11:53 -06:00
parent 3f60281da7
commit 83f2f8d239

View File

@ -77,6 +77,7 @@ class SRGANModel(BaseModel):
if self.is_train: if self.is_train:
self.netD = networks.define_D(opt).to(self.device) self.netD = networks.define_D(opt).to(self.device)
if self.spsr_enabled: if self.spsr_enabled:
logger.info("Defining grad net...")
self.netD_grad = networks.define_D(opt, wrap=True).to(self.device) # D_grad self.netD_grad = networks.define_D(opt, wrap=True).to(self.device) # D_grad
if 'network_C' in opt.keys(): if 'network_C' in opt.keys():
@ -352,7 +353,7 @@ class SRGANModel(BaseModel):
self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50 self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50
self.print_network() # print network #self.print_network() # print network
self.load() # load G and D if needed self.load() # load G and D if needed
self.load_random_corruptor() self.load_random_corruptor()