Use LR data for image gradient prediction when HR data is disjoint

James Betker 2020-08-10 15:00:28 -06:00
parent f0e2816239
commit cb316fabc7


@@ -428,16 +428,22 @@ class SRGANModel(BaseModel):
                 l_g_pix_log = l_g_pix / self.l_pix_w
                 l_g_total += l_g_pix
             if self.spsr_enabled and self.cri_pix_grad:  # gradient pixel loss
-                var_H_grad_nopadding = self.get_grad_nopadding(var_H)
-                l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding)
+                if self.disjoint_data:
+                    grad_truth = self.get_grad_nopadding(var_L)
+                    grad_pred = F.interpolate(fake_H_grad, size=grad_truth.shape[2:], mode="nearest")
+                else:
+                    grad_truth = self.get_grad_nopadding(var_H)
+                    grad_pred = fake_H_grad
+                l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(grad_pred, grad_truth)
                 l_g_total += l_g_pix_grad
             if self.spsr_enabled and self.cri_pix_branch:  # branch pixel loss
-                grad_truth = self.get_grad_nopadding(var_L)
-                downsampled_H_branch = fake_H_branch
-                if grad_truth.shape != fake_H_branch.shape:
-                    downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest")
-                l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch,
-                                                                                grad_truth)
+                if self.disjoint_data:
+                    grad_truth = self.get_grad_nopadding(var_L)
+                    grad_pred = F.interpolate(fake_H_branch, size=grad_truth.shape[2:], mode="nearest")
+                else:
+                    grad_truth = self.get_grad_nopadding(var_H)
+                    grad_pred = fake_H_branch
+                l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(grad_pred, grad_truth)
                 l_g_total += l_g_pix_grad_branch
             if self.fdpl_enabled and not using_gan_img:
                 l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
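
The pattern this commit applies, shown in isolation: when the HR data is disjoint from the LR inputs there is no paired HR ground truth, so the gradient target is derived from the LR image and the predicted HR gradient map is nearest-neighbor downsampled to the LR size before the loss is computed. A minimal standalone sketch of that logic follows; the per-channel gradient extractor here is a hypothetical stand-in for the model's get_grad_nopadding(), whose actual filter bank may differ.

```python
import torch
import torch.nn.functional as F

def get_grad_nopadding(img: torch.Tensor) -> torch.Tensor:
    # Hypothetical per-channel gradient-magnitude extractor standing in for
    # the model's get_grad_nopadding(); illustrative only.
    kx = torch.tensor([[0., 0., 0.], [-1., 0., 1.], [0., 0., 0.]],
                      device=img.device).view(1, 1, 3, 3)
    ky = kx.transpose(2, 3)
    n, c, h, w = img.shape
    flat = img.reshape(n * c, 1, h, w)   # apply the filters channel-wise
    gx = F.conv2d(flat, kx, padding=1)
    gy = F.conv2d(flat, ky, padding=1)
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6).reshape(n, c, h, w)

def gradient_pixel_loss(fake_H_grad, var_L, var_H, criterion, disjoint_data):
    # Disjoint data: take the gradient target from the LR image and shrink
    # the predicted HR gradient map to the LR spatial size; otherwise
    # compare against the HR gradient directly, as in the diff above.
    if disjoint_data:
        grad_truth = get_grad_nopadding(var_L)
        grad_pred = F.interpolate(fake_H_grad, size=grad_truth.shape[2:], mode="nearest")
    else:
        grad_truth = get_grad_nopadding(var_H)
        grad_pred = fake_H_grad
    return criterion(grad_pred, grad_truth)

# Usage: a 4x-SR setup where the HR batch is unpaired with the LR inputs.
var_L = torch.rand(1, 3, 16, 16)          # LR input
var_H = torch.rand(1, 3, 64, 64)          # HR image (disjoint here)
fake_H_grad = torch.rand(1, 3, 64, 64)    # predicted HR gradient map
loss = gradient_pixel_loss(fake_H_grad, var_L, var_H, torch.nn.L1Loss(),
                           disjoint_data=True)
```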