Allow the initial temperature of the SPSR net to be specified for inference

This commit is contained in:
James Betker 2020-08-20 11:57:34 -06:00
parent 24bdcc1181
commit 9d77a4db2e
3 changed files with 10 additions and 8 deletions

View File

@ -232,7 +232,7 @@ class SPSRNet(nn.Module):
class SwitchedSpsr(nn.Module):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
super(SwitchedSpsr, self).__init__()
n_upscale = int(math.log(upscale, 2))
@ -254,12 +254,12 @@ class SwitchedSpsr(nn.Module):
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=10,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=10,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
@ -272,7 +272,7 @@ class SwitchedSpsr(nn.Module):
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts // 2, init_temp=10,
transform_count=self.transformation_counts // 2, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
# Upsampling
self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
@ -289,7 +289,7 @@ class SwitchedSpsr(nn.Module):
self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=10,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=True)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
@ -298,7 +298,7 @@ class SwitchedSpsr(nn.Module):
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
self.attentions = None
self.init_temperature = 10
self.init_temperature = init_temperature
self.final_temperature_step = 10000
def forward(self, x):

View File

@ -113,7 +113,8 @@ def define_G(opt, net_key='network_G'):
nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == "spsr_switched":
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'])
netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
# image corruption
elif which_model == 'HighToLowResNet':

View File

@ -123,6 +123,7 @@ if __name__ == "__main__":
minivid_crf = opt['minivid_crf']
vid_output = opt['mini_vid_output_folder'] if 'mini_vid_output_folder' in opt.keys() else dataset_dir
vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
img_index = opt['generator_img_index']
ffmpeg_proc = None
tq = tqdm(test_loader)
@ -132,7 +133,7 @@ if __name__ == "__main__":
model.test()
if isinstance(model.fake_H, tuple):
visuals = model.fake_H[0].detach().float().cpu()
visuals = model.fake_H[img_index].detach().float().cpu()
else:
visuals = model.fake_H.detach().float().cpu()
for i in range(visuals.shape[0]):