forked from mrq/DL-Art-School
More fixes to corrupt_fea
This commit is contained in:
parent
0005c56cd4
commit
f9276007a8
|
@ -177,10 +177,7 @@ class SRGANModel(BaseModel):
|
||||||
self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False
|
self.use_corrupted_feature_input = train_opt['corrupted_feature_input'] if 'corrupted_feature_input' in train_opt.keys() else False
|
||||||
if self.use_corrupted_feature_input:
|
if self.use_corrupted_feature_input:
|
||||||
logger.info("Corrupting inputs into the feature network..")
|
logger.info("Corrupting inputs into the feature network..")
|
||||||
self.feature_corruptor = GaussianBlur()
|
self.feature_corruptor = GaussianBlur().to(self.device)
|
||||||
else:
|
|
||||||
logger.info("Using normal inputs into feature network..")
|
|
||||||
print(train_opt)
|
|
||||||
self.netF = networks.define_F(use_bn=False).to(self.device)
|
self.netF = networks.define_F(use_bn=False).to(self.device)
|
||||||
self.lr_netF = None
|
self.lr_netF = None
|
||||||
if 'lr_fea_path' in train_opt.keys():
|
if 'lr_fea_path' in train_opt.keys():
|
||||||
|
@ -502,8 +499,6 @@ class SRGANModel(BaseModel):
|
||||||
elif self.use_corrupted_feature_input:
|
elif self.use_corrupted_feature_input:
|
||||||
cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:])
|
cor_Pix = F.interpolate(self.feature_corruptor(pix), size=var_L.shape[2:])
|
||||||
real_fea = self.netF(cor_Pix).detach()
|
real_fea = self.netF(cor_Pix).detach()
|
||||||
if step % 50 == 0:
|
|
||||||
utils.save_image(cor_Pix.detach().cpu(), "corrupted_pix.png")
|
|
||||||
else:
|
else:
|
||||||
real_fea = self.netF(pix).detach()
|
real_fea = self.netF(pix).detach()
|
||||||
if self.use_corrupted_feature_input:
|
if self.use_corrupted_feature_input:
|
||||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/finetune_imgset_spsr_switched2_xlbatch_limfeat.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user