forked from mrq/DL-Art-School
Restore test.py for use as standalone validator
This commit is contained in:
parent
8ca566b621
commit
76e4f0c086
|
@ -745,7 +745,7 @@ class SwitchedSpsr(nn.Module):
|
||||||
weight_init_factor=.1)
|
weight_init_factor=.1)
|
||||||
|
|
||||||
# Feature branch
|
# Feature branch
|
||||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False)
|
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
|
@ -761,7 +761,7 @@ class SwitchedSpsr(nn.Module):
|
||||||
|
|
||||||
# Grad branch
|
# Grad branch
|
||||||
self.get_g_nopadding = ImageGradientNoPadding()
|
self.get_g_nopadding = ImageGradientNoPadding()
|
||||||
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False)
|
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||||
mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
|
mplex_grad = functools.partial(ConvBasisMultiplexer, nf * 2, nf * 2, switch_reductions,
|
||||||
switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
|
switch_processing_layers, self.transformation_counts // 2, use_exp2=True)
|
||||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
|
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
|
||||||
|
|
|
@ -13,7 +13,7 @@ from skimage import io
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_exd_imgset_spsr7.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_imgset_spsr_switched2_xlbatch.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
|
@ -5,14 +5,16 @@ import argparse
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import options.options as option
|
|
||||||
|
import utils
|
||||||
|
import utils.options as option
|
||||||
import utils.util as util
|
import utils.util as util
|
||||||
from data.util import bgr2ycbcr
|
from data.util import bgr2ycbcr
|
||||||
import models.archs.SwitchedResidualGenerator_arch as srg
|
import models.archs.SwitchedResidualGenerator_arch as srg
|
||||||
from switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||||
from switched_conv import compute_attention_specificity
|
from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
|
||||||
|
from switched_conv.switched_conv import compute_attention_specificity
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
from models import create_model
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
import models.networks as networks
|
import models.networks as networks
|
||||||
|
@ -91,6 +93,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
utils.util.loaded_options = opt
|
||||||
|
|
||||||
util.mkdirs(
|
util.mkdirs(
|
||||||
(path for key, path in opt['path'].items()
|
(path for key, path in opt['path'].items()
|
||||||
|
@ -108,7 +111,7 @@ if __name__ == "__main__":
|
||||||
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
|
||||||
test_loaders.append(test_loader)
|
test_loaders.append(test_loader)
|
||||||
|
|
||||||
model = create_model(opt)
|
model = ExtensibleTrainer(opt)
|
||||||
fea_loss = 0
|
fea_loss = 0
|
||||||
for test_loader in test_loaders:
|
for test_loader in test_loaders:
|
||||||
test_set_name = test_loader.dataset.opt['name']
|
test_set_name = test_loader.dataset.opt['name']
|
||||||
|
|
Loading…
Reference in New Issue
Block a user