Grey feature

This commit is contained in:
James Betker 2020-08-22 13:41:38 -06:00
parent e59e712e39
commit afdd93fbe9
3 changed files with 7 additions and 3 deletions

View File

@ -72,7 +72,12 @@ class FeatureModel(BaseModel):
def optimize_parameters(self, step):
self.optimizer_G.zero_grad()
self.fake_H = self.fea_train(self.var_L, interpolate_factor=2)
# grey out the LR image but keep 3 channels.
lr = torch.mean(self.var_L, dim=1, keepdim=True)
lr = lr.repeat(1, 3, 1, 1)
self.fake_H = self.fea_train(lr, interpolate_factor=2)
ref_H = self.net_ref(self.real_H)
l_fea = self.cri_fea(self.fake_H, ref_H)
l_fea.backward()

View File

@ -195,7 +195,6 @@ def define_fixed_D(opt):
# Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None):
device = torch.device('cuda' if gpu_ids else 'cpu')
if which_model == 'vgg':
# PyTorch pretrained VGG19-54, before ReLU.
if use_bn:

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)