forked from mrq/DL-Art-School
Use LR data for image gradient prediction when HR data is disjoint
This commit is contained in:
parent
f0e2816239
commit
cb316fabc7
|
@ -428,16 +428,22 @@ class SRGANModel(BaseModel):
|
|||
l_g_pix_log = l_g_pix / self.l_pix_w
|
||||
l_g_total += l_g_pix
|
||||
if self.spsr_enabled and self.cri_pix_grad: # gradient pixel loss
|
||||
var_H_grad_nopadding = self.get_grad_nopadding(var_H)
|
||||
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(fake_H_grad, var_H_grad_nopadding)
|
||||
if self.disjoint_data:
|
||||
grad_truth = self.get_grad_nopadding(var_L)
|
||||
grad_pred = F.interpolate(fake_H_grad, size=grad_truth.shape[2:], mode="nearest")
|
||||
else:
|
||||
grad_truth = self.get_grad_nopadding(var_H)
|
||||
grad_pred = fake_H_grad
|
||||
l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(grad_pred, grad_truth)
|
||||
l_g_total += l_g_pix_grad
|
||||
if self.spsr_enabled and self.cri_pix_branch: # branch pixel loss
|
||||
if self.disjoint_data:
|
||||
grad_truth = self.get_grad_nopadding(var_L)
|
||||
downsampled_H_branch = fake_H_branch
|
||||
if grad_truth.shape != fake_H_branch.shape:
|
||||
downsampled_H_branch = F.interpolate(downsampled_H_branch, size=grad_truth.shape[2:], mode="nearest")
|
||||
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(downsampled_H_branch,
|
||||
grad_truth)
|
||||
grad_pred = F.interpolate(fake_H_branch, size=grad_truth.shape[2:], mode="nearest")
|
||||
else:
|
||||
grad_truth = self.get_grad_nopadding(var_H)
|
||||
grad_pred = fake_H_branch
|
||||
l_g_pix_grad_branch = self.l_pix_branch_w * self.cri_pix_branch(grad_pred, grad_truth)
|
||||
l_g_total += l_g_pix_grad_branch
|
||||
if self.fdpl_enabled and not using_gan_img:
|
||||
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
|
||||
|
|
Loading…
Reference in New Issue
Block a user