forked from mrq/DL-Art-School
train HR feature trainer
This commit is contained in:
parent
0e859a8082
commit
a56e906f9c
|
@ -74,10 +74,10 @@ class FeatureModel(BaseModel):
|
||||||
self.optimizer_G.zero_grad()
|
self.optimizer_G.zero_grad()
|
||||||
|
|
||||||
# grey out the LR image but keep 3 channels.
|
# grey out the LR image but keep 3 channels.
|
||||||
lr = torch.mean(self.var_L, dim=1, keepdim=True)
|
lr = torch.mean(self.var_H, dim=1, keepdim=True)
|
||||||
lr = lr.repeat(1, 3, 1, 1)
|
lr = lr.repeat(1, 3, 1, 1)
|
||||||
|
|
||||||
self.fake_H = self.fea_train(lr, interpolate_factor=2)
|
self.fake_H = self.fea_train(lr, interpolate_factor=1)
|
||||||
ref_H = self.net_ref(self.real_H)
|
ref_H = self.net_ref(self.real_H)
|
||||||
l_fea = self.cri_fea(self.fake_H, ref_H)
|
l_fea = self.cri_fea(self.fake_H, ref_H)
|
||||||
l_fea.backward()
|
l_fea.backward()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user