Grey feature
This commit is contained in:
parent
e59e712e39
commit
afdd93fbe9
|
@ -72,7 +72,12 @@ class FeatureModel(BaseModel):
|
|||
|
||||
def optimize_parameters(self, step):
|
||||
self.optimizer_G.zero_grad()
|
||||
self.fake_H = self.fea_train(self.var_L, interpolate_factor=2)
|
||||
|
||||
# grey out the LR image but keep 3 channels.
|
||||
lr = torch.mean(self.var_L, dim=1, keepdim=True)
|
||||
lr = lr.repeat(1, 3, 1, 1)
|
||||
|
||||
self.fake_H = self.fea_train(lr, interpolate_factor=2)
|
||||
ref_H = self.net_ref(self.real_H)
|
||||
l_fea = self.cri_fea(self.fake_H, ref_H)
|
||||
l_fea.backward()
|
||||
|
|
|
@ -195,7 +195,6 @@ def define_fixed_D(opt):
|
|||
|
||||
# Define network used for perceptual loss
|
||||
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None):
|
||||
device = torch.device('cuda' if gpu_ids else 'cpu')
|
||||
if which_model == 'vgg':
|
||||
# PyTorch pretrained VGG19-54, before ReLU.
|
||||
if use_bn:
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user