Adapt srg2 for video
This commit is contained in:
parent
b742d1e5a5
commit
a1760f8969
|
@ -108,10 +108,14 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
||||||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||||
heightened_final_step=50000, upsample_factor=1,
|
heightened_final_step=50000, upsample_factor=1,
|
||||||
add_scalable_noise_to_transforms=False):
|
add_scalable_noise_to_transforms=False, for_video=False):
|
||||||
super(ConfigurableSwitchedResidualGenerator2, self).__init__()
|
super(ConfigurableSwitchedResidualGenerator2, self).__init__()
|
||||||
switches = []
|
switches = []
|
||||||
self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
|
self.for_video = for_video
|
||||||
|
if for_video:
|
||||||
|
self.initial_conv = ConvBnLelu(6, transformation_filters, stride=upsample_factor, norm=False, activation=False, bias=True)
|
||||||
|
else:
|
||||||
|
self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
|
||||||
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||||
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||||
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||||
|
@ -135,8 +139,15 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
||||||
self.upsample_factor = upsample_factor
|
self.upsample_factor = upsample_factor
|
||||||
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, ref=None):
|
||||||
x = self.initial_conv(x)
|
if self.for_video:
|
||||||
|
x_lg = F.interpolate(x, scale_factor=self.upsample_factor, mode="bicubic")
|
||||||
|
if ref is None:
|
||||||
|
ref = torch.zeros_like(x_lg)
|
||||||
|
x_lg = torch.cat([x_lg, ref], dim=1)
|
||||||
|
else:
|
||||||
|
x_lg = x
|
||||||
|
x = self.initial_conv(x_lg)
|
||||||
|
|
||||||
self.attentions = []
|
self.attentions = []
|
||||||
for i, sw in enumerate(self.switches):
|
for i, sw in enumerate(self.switches):
|
||||||
|
|
|
@ -67,7 +67,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
|
transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
|
||||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
|
||||||
|
for_video=opt_net['for_video'])
|
||||||
elif which_model == "srg2classic":
|
elif which_model == "srg2classic":
|
||||||
netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
||||||
switch_reductions=opt_net['switch_reductions'],
|
switch_reductions=opt_net['switch_reductions'],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user