diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 21d6732a..1d13bc95 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -428,16 +428,22 @@ class SRGANModel(BaseModel): l_g_pix_log = l_g_pix / self.l_pix_w l_g_total += l_g_pix if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss - var_H_grad_nopadding = self.get_grad_nopadding(var_H) - l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding) + if self.disjoint_data: + grad_truth = self.get_grad_nopadding(var_L) + grad_pred = F.interpolate(fake_H_grad, size=grad_truth.shape[2:], mode="nearest") + else: + grad_truth = self.get_grad_nopadding(var_H) + grad_pred = fake_H_grad + l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(grad_pred, grad_truth) l_g_total += l_g_pix_grad if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss - grad_truth = self.get_grad_nopadding(var_L) - downsampled_H_branch = fake_H_branch - if grad_truth.shape != fake_H_branch.shape: - downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest") - l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch, - grad_truth) + if self.disjoint_data: + grad_truth = self.get_grad_nopadding(var_L) + grad_pred = F.interpolate(fake_H_branch, size=grad_truth.shape[2:], mode="nearest") + else: + grad_truth = self.get_grad_nopadding(var_H) + grad_pred = fake_H_branch + l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(grad_pred, grad_truth) l_g_total += l_g_pix_grad_branch if self.fdpl_enabled and not using_gan_img: l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)