corfea debugging

This commit is contained in:
James Betker 2020-08-23 17:39:02 -06:00
parent 7713cb8df5
commit 4bb5b3c981

View File

@ -176,6 +176,7 @@ class SRGANModel(BaseModel):
if self.cri_fea: # load VGG perceptual loss if self.cri_fea: # load VGG perceptual loss
self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False
if self.use_corrupted_feature_input: if self.use_corrupted_feature_input:
logger.info("Corrupting inputs into the feature network..")
self.feature_corruptor = GaussianBlur() self.feature_corruptor = GaussianBlur()
self.netF = networks.define_F(use_bn=False).to(self.device) self.netF = networks.define_F(use_bn=False).to(self.device)
self.lr_netF = None self.lr_netF = None
@ -498,6 +499,8 @@ class SRGANModel(BaseModel):
elif self.use_corrupted_feature_input: elif self.use_corrupted_feature_input:
cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:]) cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:])
real_fea = self.netF(cor_Pix).detach() real_fea = self.netF(cor_Pix).detach()
if step % 50 == 0:
utils.save_image(cor_Pix.detach().cpu(), "corrupted_pix.png")
else: else:
real_fea = self.netF(pix).detach() real_fea = self.netF(pix).detach()
if self.use_corrupted_feature_input: if self.use_corrupted_feature_input: