forked from mrq/DL-Art-School
Adapt srg2 for video
This commit is contained in:
parent
b742d1e5a5
commit
a1760f8969
|
@ -108,10 +108,14 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
|||
def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
|
||||
trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
|
||||
heightened_final_step=50000, upsample_factor=1,
|
||||
add_scalable_noise_to_transforms=False):
|
||||
add_scalable_noise_to_transforms=False, for_video=False):
|
||||
super(ConfigurableSwitchedResidualGenerator2, self).__init__()
|
||||
switches = []
|
||||
self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
|
||||
self.for_video = for_video
|
||||
if for_video:
|
||||
self.initial_conv = ConvBnLelu(6, transformation_filters, stride=upsample_factor, norm=False, activation=False, bias=True)
|
||||
else:
|
||||
self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
|
||||
self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||
self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||
self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
|
||||
|
@ -135,8 +139,15 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
|
|||
self.upsample_factor = upsample_factor
|
||||
assert self.upsample_factor == 2 or self.upsample_factor == 4
|
||||
|
||||
def forward(self, x):
|
||||
x = self.initial_conv(x)
|
||||
def forward(self, x, ref=None):
|
||||
if self.for_video:
|
||||
x_lg = F.interpolate(x, scale_factor=self.upsample_factor, mode="bicubic")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(x_lg)
|
||||
x_lg = torch.cat([x_lg, ref], dim=1)
|
||||
else:
|
||||
x_lg = x
|
||||
x = self.initial_conv(x_lg)
|
||||
|
||||
self.attentions = []
|
||||
for i, sw in enumerate(self.switches):
|
||||
|
|
|
@ -67,7 +67,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
|
||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
|
||||
for_video=opt_net['for_video'])
|
||||
elif which_model == "srg2classic":
|
||||
netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
|
||||
switch_reductions=opt_net['switch_reductions'],
|
||||
|
|
Loading…
Reference in New Issue
Block a user