diff --git a/codes/models/archs/srg2_classic.py b/codes/models/archs/srg2_classic.py index 0415caa6..b33cccca 100644 --- a/codes/models/archs/srg2_classic.py +++ b/codes/models/archs/srg2_classic.py @@ -108,10 +108,14 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, heightened_final_step=50000, upsample_factor=1, - add_scalable_noise_to_transforms=False): + add_scalable_noise_to_transforms=False, for_video=False): super(ConfigurableSwitchedResidualGenerator2, self).__init__() switches = [] - self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True) + self.for_video = for_video + if for_video: + self.initial_conv = ConvBnLelu(6, transformation_filters, stride=upsample_factor, norm=False, activation=False, bias=True) + else: + self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True) self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) @@ -135,8 +139,15 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): self.upsample_factor = upsample_factor assert self.upsample_factor == 2 or self.upsample_factor == 4 - def forward(self, x): - x = self.initial_conv(x) + def forward(self, x, ref=None): + if self.for_video: + x_lg = F.interpolate(x, scale_factor=self.upsample_factor, mode="bicubic") + if ref is None: + ref = torch.zeros_like(x_lg) + x_lg = torch.cat([x_lg, ref], dim=1) + else: + x_lg = x + x = self.initial_conv(x_lg) self.attentions = [] for i, sw in enumerate(self.switches): diff --git a/codes/models/networks.py b/codes/models/networks.py index feefbd69..e5a79536 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -67,7 +67,8 @@ def define_G(opt, net_key='network_G', scale=None): transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'], initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], - upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) + upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'], + for_video=opt_net['for_video']) elif which_model == "srg2classic": netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'], switch_reductions=opt_net['switch_reductions'],