train HR feature trainer

This commit is contained in:
James Betker 2020-08-29 09:27:48 -06:00
parent 0e859a8082
commit a56e906f9c

View File

@ -74,10 +74,10 @@ class FeatureModel(BaseModel):
self.optimizer_G.zero_grad() self.optimizer_G.zero_grad()
# grey out the LR image but keep 3 channels. # grey out the LR image but keep 3 channels.
lr = torch.mean(self.var_L, dim=1, keepdim=True) lr = torch.mean(self.var_H, dim=1, keepdim=True)
lr = lr.repeat(1, 3, 1, 1) lr = lr.repeat(1, 3, 1, 1)
self.fake_H = self.fea_train(lr, interpolate_factor=2) self.fake_H = self.fea_train(lr, interpolate_factor=1)
ref_H = self.net_ref(self.real_H) ref_H = self.net_ref(self.real_H)
l_fea = self.cri_fea(self.fake_H, ref_H) l_fea = self.cri_fea(self.fake_H, ref_H)
l_fea.backward() l_fea.backward()