From 76e4f0c086274c93480f114cb76d1513d5a3b3fb Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Mon, 19 Oct 2020 15:26:07 -0600
Subject: [PATCH] Restore test.py for use as standalone validator

---
 codes/models/archs/SPSR_arch.py |  4 ++--
 codes/scripts/validate_data.py  |  2 +-
 codes/test.py                   | 13 ++++++++-----
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index ad23a5fd..ab8277e7 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -745,7 +745,7 @@ class SwitchedSpsr(nn.Module):
                                          weight_init_factor=.1)
 
         # Feature branch
-        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
+        self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                    pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                    attention_norm=True,
@@ -761,7 +761,7 @@ class SwitchedSpsr(nn.Module):
 
         # Grad branch
         self.get_g_nopadding = ImageGradientNoPadding()
-        self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False)
+        self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
         mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
                                         switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
diff --git a/codes/scripts/validate_data.py b/codes/scripts/validate_data.py
index e96084e0..c2fb28cd 100644
--- a/codes/scripts/validate_data.py
+++ b/codes/scripts/validate_data.py
@@ -13,7 +13,7 @@ from skimage import io
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_exd_imgset_spsr7.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_imgset_spsr_switched2_xlbatch.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
diff --git a/codes/test.py b/codes/test.py
index f025290b..6c783a14 100644
--- a/codes/test.py
+++ b/codes/test.py
@@ -5,14 +5,16 @@ import argparse
 from collections import OrderedDict
 
 import os
-import options.options as option
+
+import utils
+import utils.options as option
 import utils.util as util
 from data.util import bgr2ycbcr
 import models.archs.SwitchedResidualGenerator_arch as srg
-from switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
-from switched_conv import compute_attention_specificity
+from models.ExtensibleTrainer import ExtensibleTrainer
+from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
+from switched_conv.switched_conv import compute_attention_specificity
 from data import create_dataset, create_dataloader
-from models import create_model
 from tqdm import tqdm
 import torch
 import models.networks as networks
@@ -91,6 +93,7 @@ if __name__ == "__main__":
     parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
+    utils.util.loaded_options = opt
 
     util.mkdirs(
         (path for key, path in opt['path'].items()
@@ -108,7 +111,7 @@ if __name__ == "__main__":
         logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
         test_loaders.append(test_loader)
 
-    model = create_model(opt)
+    model = ExtensibleTrainer(opt)
     fea_loss = 0
     for test_loader in test_loaders:
         test_set_name = test_loader.dataset.opt['name']