forked from mrq/DL-Art-School
Allow initial temperature to be specified to SPSR net for inference
parent 24bdcc1181
commit 9d77a4db2e
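In short, SwitchedSpsr previously hard-coded its switch temperature to 10; this change threads an init_temperature argument (default 10, so existing configs behave the same) through every ConfigurableSwitchComputer and exposes it through an optional 'temperature' key in the network options, letting an inference run start the switches at a chosen temperature. A minimal sketch of calling the updated constructor directly; the import path and the choice of 1 for inference are assumptions, not part of this commit:

# Hypothetical direct use of the new keyword (illustration only).
# The import path is assumed from the usual repo layout; a fully annealed
# temperature of 1 is only a guess at what an inference run would want.
import models.archs.SPSR_arch as spsr

netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=64, xforms=8, upscale=4,
                         init_temperature=1)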
@@ -232,7 +232,7 @@ class SPSRNet(nn.Module):


 class SwitchedSpsr(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4):
+    def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
         super(SwitchedSpsr, self).__init__()
         n_upscale = int(math.log(upscale, 2))
@@ -254,12 +254,12 @@ class SwitchedSpsr(nn.Module):
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
-                                              transform_count=self.transformation_counts, init_temp=10,
+                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                               add_scalable_noise_to_transforms=True)
         self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                               attention_norm=True,
-                                              transform_count=self.transformation_counts, init_temp=10,
+                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                               add_scalable_noise_to_transforms=True)
         self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
         self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False)
@@ -272,7 +272,7 @@ class SwitchedSpsr(nn.Module):
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, mplex_grad,
                                                   pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
-                                                  transform_count=self.transformation_counts // 2, init_temp=10,
+                                                  transform_count=self.transformation_counts // 2, init_temp=init_temperature,
                                                   add_scalable_noise_to_transforms=True)
         # Upsampling
         self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
@@ -289,7 +289,7 @@ class SwitchedSpsr(nn.Module):
         self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                               pre_transform_block=pretransform_fn_cat, transform_block=transform_fn_cat,
                                                               attention_norm=True,
-                                                              transform_count=self.transformation_counts, init_temp=10,
+                                                              transform_count=self.transformation_counts, init_temp=init_temperature,
                                                               add_scalable_noise_to_transforms=True)
         self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
         self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)])
@@ -298,7 +298,7 @@ class SwitchedSpsr(nn.Module):
         self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
         self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw]
         self.attentions = None
-        self.init_temperature = 10
+        self.init_temperature = init_temperature
         self.final_temperature_step = 10000

     def forward(self, x):
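For context, init_temperature is stored next to final_temperature_step, which matches the pattern in this codebase of annealing the switch temperature toward 1 over training. A rough sketch of such a schedule, assuming a linear decay; the actual update hook is not part of this diff and may differ:

# Assumed linear anneal implied by init_temperature / final_temperature_step.
# Not taken from the repo; shown only to explain what the two fields control.
def switch_temperature(step, init_temperature=10, final_temperature_step=10000):
    remaining = max(0, final_temperature_step - step) / final_temperature_step
    return 1 + (init_temperature - 1) * remaining  # init_temperature at step 0, 1 from final step onward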
@@ -113,7 +113,8 @@ def define_G(opt, net_key='network_G'):
                                 nb=opt_net['nb'], upscale=opt_net['scale'])
     elif which_model == "spsr_switched":
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'])
+        netG = spsr.SwitchedSpsr(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
+                                 init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)

     # image corruption
     elif which_model == 'HighToLowResNet':
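On the options side, define_G now also reads an optional 'temperature' key from the network_G block, falling back to 10 when it is absent. A sketch of what that block might look like for the spsr_switched model; apart from 'temperature', the keys and values are illustrative placeholders inferred from the lookups above:

# Illustrative network_G options; only the optional 'temperature' key is new here.
opt['network_G'] = {
    'which_model_G': 'spsr_switched',  # key name assumed from the existing config convention
    'nf': 64,                          # placeholder value
    'scale': 4,
    'num_transforms': 8,               # optional, defaults to 8
    'temperature': 1,                  # optional, defaults to 10 when omitted
}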
@@ -123,6 +123,7 @@ if __name__ == "__main__":
     minivid_crf = opt['minivid_crf']
     vid_output = opt['mini_vid_output_folder'] if 'mini_vid_output_folder' in opt.keys() else dataset_dir
     vid_counter = opt['minivid_start_no'] if 'minivid_start_no' in opt.keys() else 0
+    img_index = opt['generator_img_index']
     ffmpeg_proc = None

     tq = tqdm(test_loader)
@@ -132,7 +133,7 @@ if __name__ == "__main__":
         model.test()

         if isinstance(model.fake_H, tuple):
-            visuals = model.fake_H[0].detach().float().cpu()
+            visuals = model.fake_H[img_index].detach().float().cpu()
         else:
             visuals = model.fake_H.detach().float().cpu()
         for i in range(visuals.shape[0]):
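Finally, the generation script stops hard-coding index 0 when the generator returns multiple outputs: opt['generator_img_index'] now picks which element of the fake_H tuple gets written out. A trivial sketch of that selection, with invented placeholder tensor names:

# Placeholder illustration of generator_img_index; tensor names are invented.
fake_H = (sr_image, gradient_branch_image)   # e.g. a 2-tuple returned by the model
visuals = fake_H[opt['generator_img_index']].detach().float().cpu()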