Restore test.py for use as standalone validator
This commit is contained in:
parent
8ca566b621
commit
76e4f0c086
|
@ -745,7 +745,7 @@ class SwitchedSpsr(nn.Module):
|
|||
weight_init_factor=.1)
|
||||
|
||||
# Feature branch
|
||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
|
||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
attention_norm=True,
|
||||
|
@ -761,7 +761,7 @@ class SwitchedSpsr(nn.Module):
|
|||
|
||||
# Grad branch
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False)
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
|
||||
switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
|
||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
|
||||
|
|
|
@ -13,7 +13,7 @@ from skimage import io
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_exd_imgset_spsr7.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_imgset_spsr_switched2_xlbatch.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
|
|
@ -5,14 +5,16 @@ import argparse
|
|||
from collections import OrderedDict
|
||||
|
||||
import os
|
||||
import options.options as option
|
||||
|
||||
import utils
|
||||
import utils.options as option
|
||||
import utils.util as util
|
||||
from data.util import bgr2ycbcr
|
||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||
from switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
||||
from switched_conv import compute_attention_specificity
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
||||
from switched_conv.switched_conv import compute_attention_specificity
|
||||
from data import create_dataset, create_dataloader
|
||||
from models import create_model
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import models.networks as networks
|
||||
|
@ -91,6 +93,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
||||
util.mkdirs(
|
||||
(path for key, path in opt['path'].items()
|
||||
|
@ -108,7 +111,7 @@ if __name__ == "__main__":
|
|||
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||
test_loaders.append(test_loader)
|
||||
|
||||
model = create_model(opt)
|
||||
model = ExtensibleTrainer(opt)
|
||||
fea_loss = 0
|
||||
for test_loader in test_loaders:
|
||||
test_set_name = test_loader.dataset.opt['name']
|
||||
|
|
Loading…
Reference in New Issue
Block a user