Adapt srg2 for video

This commit is contained in:
James Betker 2020-11-10 16:16:41 -07:00
parent b742d1e5a5
commit a1760f8969
2 changed files with 17 additions and 5 deletions

View File

@ -108,10 +108,14 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
heightened_final_step=50000, upsample_factor=1,
add_scalable_noise_to_transforms=False):
add_scalable_noise_to_transforms=False, for_video=False):
super(ConfigurableSwitchedResidualGenerator2, self).__init__()
switches = []
self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
self.for_video = for_video
if for_video:
self.initial_conv = ConvBnLelu(6, transformation_filters, stride=upsample_factor, norm=False, activation=False, bias=True)
else:
self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
@ -135,8 +139,15 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
self.upsample_factor = upsample_factor
assert self.upsample_factor == 2 or self.upsample_factor == 4
def forward(self, x):
x = self.initial_conv(x)
def forward(self, x, ref=None):
if self.for_video:
x_lg = F.interpolate(x, scale_factor=self.upsample_factor, mode="bicubic")
if ref is None:
ref = torch.zeros_like(x_lg)
x_lg = torch.cat([x_lg, ref], dim=1)
else:
x_lg = x
x = self.initial_conv(x_lg)
self.attentions = []
for i, sw in enumerate(self.switches):

View File

@ -67,7 +67,8 @@ def define_G(opt, net_key='network_G', scale=None):
transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
for_video=opt_net['for_video'])
elif which_model == "srg2classic":
netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],