Use LR data for image gradient prediction when HR data is disjoint

This commit is contained in:
James Betker 2020-08-10 15:00:28 -06:00
parent f0e2816239
commit cb316fabc7

View File

@ -428,16 +428,22 @@ class SRGANModel(BaseModel):
l_g_pix_log = l_g_pix / self.l_pix_w l_g_pix_log = l_g_pix / self.l_pix_w
l_g_total += l_g_pix l_g_total += l_g_pix
if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss
var_H_grad_nopadding = self.get_grad_nopadding(var_H) if self.disjoint_data:
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding) grad_truth = self.get_grad_nopadding(var_L)
grad_pred = F.interpolate(fake_H_grad, size=grad_truth.shape[2:], mode="nearest")
else:
grad_truth = self.get_grad_nopadding(var_H)
grad_pred = fake_H_grad
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(grad_pred, grad_truth)
l_g_total += l_g_pix_grad l_g_total += l_g_pix_grad
if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss
if self.disjoint_data:
grad_truth = self.get_grad_nopadding(var_L) grad_truth = self.get_grad_nopadding(var_L)
downsampled_H_branch = fake_H_branch grad_pred = F.interpolate(fake_H_branch, size=grad_truth.shape[2:], mode="nearest")
if grad_truth.shape != fake_H_branch.shape: else:
downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest") grad_truth = self.get_grad_nopadding(var_H)
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch, grad_pred = fake_H_branch
grad_truth) l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(grad_pred, grad_truth)
l_g_total += l_g_pix_grad_branch l_g_total += l_g_pix_grad_branch
if self.fdpl_enabled and not using_gan_img: if self.fdpl_enabled and not using_gan_img:
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix) l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)