extra_conv in gn discriminator, multiframe support in rrdb.
This commit is contained in:
parent
da604752e6
commit
1e0f69e34b
|
@ -177,6 +177,7 @@ class RRDBNet(nn.Module):
|
|||
feature_channels=64, # Only applicable when headless=True. How many channels are used at the trunk level.
|
||||
output_mode="hq_only", # Options: "hq_only", "hq+features", "features_only"
|
||||
initial_stride=1,
|
||||
use_ref=False, # When set, a reference image is expected as input and synthesized if not found. Useful for video SR.
|
||||
):
|
||||
super(RRDBNet, self).__init__()
|
||||
assert output_mode in ['hq_only', 'hq+features', 'features_only']
|
||||
|
@ -186,7 +187,8 @@ class RRDBNet(nn.Module):
|
|||
self.scale = scale
|
||||
self.in_channels = in_channels
|
||||
self.output_mode = output_mode
|
||||
first_conv_stride = initial_stride if in_channels <= 4 else scale
|
||||
self.use_ref = use_ref
|
||||
first_conv_stride = initial_stride if not self.use_ref else scale
|
||||
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
||||
if headless:
|
||||
|
@ -242,7 +244,7 @@ class RRDBNet(nn.Module):
|
|||
feat = x
|
||||
else:
|
||||
# "Normal" mode -> image input.
|
||||
if self.in_channels > 4:
|
||||
if self.use_ref:
|
||||
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(x_lg)
|
||||
|
|
|
@ -83,7 +83,7 @@ class Discriminator_VGG_128(nn.Module):
|
|||
|
||||
class Discriminator_VGG_128_GN(nn.Module):
|
||||
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False):
|
||||
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False, extra_conv=False):
|
||||
super(Discriminator_VGG_128_GN, self).__init__()
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
|
@ -111,6 +111,14 @@ class Discriminator_VGG_128_GN(nn.Module):
|
|||
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
|
||||
self.extra_conv = extra_conv
|
||||
if extra_conv:
|
||||
self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||
self.bn5_0 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||
self.bn5_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||
input_img_factor = input_img_factor // 2
|
||||
final_nf = nf * 8
|
||||
|
||||
# activation function
|
||||
|
@ -136,6 +144,10 @@ class Discriminator_VGG_128_GN(nn.Module):
|
|||
|
||||
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
||||
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
||||
|
||||
if self.extra_conv:
|
||||
fea = self.lrelu(self.bn5_0(self.conv5_0(fea)))
|
||||
fea = self.lrelu(self.bn5_1(self.conv5_1(fea)))
|
||||
return fea
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -196,7 +196,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
|||
if which_model == 'discriminator_vgg_128':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
|
||||
elif which_model == 'discriminator_vgg_128_gn':
|
||||
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||
extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False
|
||||
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
|
||||
input_img_factor=img_sz / 128, extra_conv=extra_conv)
|
||||
if wrap:
|
||||
netD = GradDiscWrapper(netD)
|
||||
elif which_model == 'discriminator_vgg_128_gn_checkpointed':
|
||||
|
|
|
@ -56,6 +56,8 @@ def create_injector(opt_inject, env):
|
|||
return BatchRotateInjector(opt_inject, env)
|
||||
elif type == 'sr_diffs':
|
||||
return SrDiffsInjector(opt_inject, env)
|
||||
elif type == 'multiframe_combiner':
|
||||
return MultiFrameCombiner(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -419,3 +421,52 @@ class SrDiffsInjector(Injector):
|
|||
elif self.mode == 'recombine':
|
||||
combined = resampled_lq + hq
|
||||
return {self.output: combined}
|
||||
|
||||
|
||||
class MultiFrameCombiner(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.mode = opt['mode']
|
||||
self.dim = opt['dim'] if 'dim' in opt.keys() else None
|
||||
self.flow = opt['flow']
|
||||
self.in_lq_key = opt['in']
|
||||
self.in_hq_key = opt['in_hq']
|
||||
self.out_lq_key = opt['out']
|
||||
self.out_hq_key = opt['out_hq']
|
||||
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
||||
self.resampler = Resample2d()
|
||||
|
||||
def combine(self, state):
|
||||
flow = self.env['generators'][self.flow]
|
||||
lq = state[self.in_lq_key]
|
||||
hq = state[self.in_hq_key]
|
||||
b, f, c, h, w = lq.shape
|
||||
center = f // 2
|
||||
center_img = lq[:,center,:,:,:]
|
||||
imgs = [center_img]
|
||||
with torch.no_grad():
|
||||
for i in range(f):
|
||||
if i == center:
|
||||
continue
|
||||
nimg = lq[:,i,:,:,:]
|
||||
flowfield = flow(torch.stack([center_img, nimg], dim=2).float())
|
||||
nimg = self.resampler(nimg, flowfield)
|
||||
imgs.append(nimg)
|
||||
hq_out = hq[:,center,:,:,:]
|
||||
return {self.out_lq_key: torch.cat(imgs, dim=1),
|
||||
self.out_hq_key: hq_out,
|
||||
self.out_lq_key + "_flow_sample": torch.cat(imgs, dim=0)}
|
||||
|
||||
def synthesize(self, state):
|
||||
lq = state[self.in_lq_key]
|
||||
return {
|
||||
self.out_lq_key: lq.repeat(1, self.dim, 1, 1)
|
||||
}
|
||||
|
||||
def forward(self, state):
|
||||
if self.mode == "synthesize":
|
||||
return self.synthesize(state)
|
||||
elif self.mode == "combine":
|
||||
return self.combine(state)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_exd_imgsetext_rrdb4x_6bl_2stride/train_exd_imgsetext_rrdb4x_6bl_2stride.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb4x_2stride_multiframe.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user