Restore test.py for use as standalone validator

This commit is contained in:
James Betker 2020-10-19 15:26:07 -06:00
parent 8ca566b621
commit 76e4f0c086
3 changed files with 11 additions and 8 deletions

View File

@ -745,7 +745,7 @@ class SwitchedSpsr(nn.Module):
weight_init_factor=.1) weight_init_factor=.1)
# Feature branch # Feature branch
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn, pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True, attention_norm=True,
@ -761,7 +761,7 @@ class SwitchedSpsr(nn.Module):
# Grad branch # Grad branch
self.get_g_nopadding = ImageGradientNoPadding() self.get_g_nopadding = ImageGradientNoPadding()
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False) self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions, mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
switch_processing_layers, self.transformation_counts // 2, use_exp2=True) switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad, self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,

View File

@ -13,7 +13,7 @@ from skimage import io
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_exd_imgset_spsr7.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_imgset_spsr_switched2_xlbatch.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)

View File

@ -5,14 +5,16 @@ import argparse
from collections import OrderedDict from collections import OrderedDict
import os import os
import options.options as option
import utils
import utils.options as option
import utils.util as util import utils.util as util
from data.util import bgr2ycbcr from data.util import bgr2ycbcr
import models.archs.SwitchedResidualGenerator_arch as srg import models.archs.SwitchedResidualGenerator_arch as srg
from switched_conv_util import save_attention_to_image, save_attention_to_image_rgb from models.ExtensibleTrainer import ExtensibleTrainer
from switched_conv import compute_attention_specificity from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
from switched_conv.switched_conv import compute_attention_specificity
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from models import create_model
from tqdm import tqdm from tqdm import tqdm
import torch import torch
import models.networks as networks import models.networks as networks
@ -91,6 +93,7 @@ if __name__ == "__main__":
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml') parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
util.mkdirs( util.mkdirs(
(path for key, path in opt['path'].items() (path for key, path in opt['path'].items()
@ -108,7 +111,7 @@ if __name__ == "__main__":
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
test_loaders.append(test_loader) test_loaders.append(test_loader)
model = create_model(opt) model = ExtensibleTrainer(opt)
fea_loss = 0 fea_loss = 0
for test_loader in test_loaders: for test_loader in test_loaders:
test_set_name = test_loader.dataset.opt['name'] test_set_name = test_loader.dataset.opt['name']