From 9d77a4db2e111522e8a84c2315fe0fdf9b23cd4b Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 20 Aug 2020 11:57:34 -0600 Subject: [PATCH] Allow initial temperature to be specified to SPSR net for inference --- codes/models/archs/SPSR_arch.py | 12 ++++++------ codes/models/networks.py | 3 ++- codes/process_video.py | 3 ++- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 965d810e..223f8af1 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -232,7 +232,7 @@ class SPSRNet(nn.Module): class SwitchedSpsr(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): super(SwitchedSpsr, self).__init__() n_upscale = int(math.log(upscale, 2)) @@ -254,12 +254,12 @@ class SwitchedSpsr(nn.Module): self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, pre_transform_block=pretransform_fn, transform_block=transform_fn, attention_norm=True, - transform_count=self.transformation_counts, init_temp=10, + transform_count=self.transformation_counts, init_temp=init_temperature, add_scalable_noise_to_transforms=True) self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, pre_transform_block=pretransform_fn, transform_block=transform_fn, attention_norm=True, - transform_count=self.transformation_counts, init_temp=10, + transform_count=self.transformation_counts, init_temp=init_temperature, add_scalable_noise_to_transforms=True) self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) @@ -272,7 +272,7 @@ class SwitchedSpsr(nn.Module): self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad, pre_transform_block=pretransform_fn, transform_block=transform_fn, attention_norm=True, - transform_count=self.transformation_counts // 2, init_temp=10, + transform_count=self.transformation_counts // 2, init_temp=init_temperature, add_scalable_noise_to_transforms=True) # Upsampling self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False) @@ -289,7 +289,7 @@ class SwitchedSpsr(nn.Module): self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat, attention_norm=True, - transform_count=self.transformation_counts, init_temp=10, + transform_count=self.transformation_counts, init_temp=init_temperature, add_scalable_noise_to_transforms=True) self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) @@ -298,7 +298,7 @@ class SwitchedSpsr(nn.Module): self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False) self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw] self.attentions = None - self.init_temperature = 10 + self.init_temperature = init_temperature self.final_temperature_step = 10000 def forward(self, x): diff --git a/codes/models/networks.py b/codes/models/networks.py index 4d8a0d25..6cfd2089 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -113,7 +113,8 @@ def define_G(opt, net_key='network_G'): nb=opt_net['nb'], upscale=opt_net['scale']) elif which_model == "spsr_switched": xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 - netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale']) + netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) # image corruption elif which_model == 'HighToLowResNet': diff --git a/codes/process_video.py b/codes/process_video.py index 696a6b15..21dbfeb9 100644 --- a/codes/process_video.py +++ b/codes/process_video.py @@ -123,6 +123,7 @@ if __name__ == "__main__": minivid_crf = opt['minivid_crf'] vid_output = opt['mini_vid_output_folder'] if 'mini_vid_output_folder' in opt.keys() else dataset_dir vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0 + img_index = opt['generator_img_index'] ffmpeg_proc = None tq = tqdm(test_loader) @@ -132,7 +133,7 @@ if __name__ == "__main__": model.test() if isinstance(model.fake_H, tuple): - visuals = model.fake_H[0].detach().float().cpu() + visuals = model.fake_H[img_index].detach().float().cpu() else: visuals = model.fake_H.detach().float().cpu() for i in range(visuals.shape[0]):