extra_conv in gn discriminator, multiframe support in rrdb.

This commit is contained in:
James Betker 2020-11-29 15:39:50 -07:00
parent da604752e6
commit 1e0f69e34b
5 changed files with 72 additions and 5 deletions

View File

@ -177,6 +177,7 @@ class RRDBNet(nn.Module):
feature_channels=64, # Only applicable when headless=True. How many channels are used at the trunk level.
output_mode="hq_only", # Options: "hq_only", "hq+features", "features_only"
initial_stride=1,
use_ref=False, # When set, a reference image is expected as input and synthesized if not found. Useful for video SR.
):
super(RRDBNet, self).__init__()
assert output_mode in ['hq_only', 'hq+features', 'features_only']
@ -186,7 +187,8 @@ class RRDBNet(nn.Module):
self.scale = scale
self.in_channels = in_channels
self.output_mode = output_mode
first_conv_stride = initial_stride if in_channels <= 4 else scale
self.use_ref = use_ref
first_conv_stride = initial_stride if not self.use_ref else scale
first_conv_ksize = 3 if first_conv_stride == 1 else 7
first_conv_padding = 1 if first_conv_stride == 1 else 3
if headless:
@ -242,7 +244,7 @@ class RRDBNet(nn.Module):
feat = x
else:
# "Normal" mode -> image input.
if self.in_channels > 4:
if self.use_ref:
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
if ref is None:
ref = torch.zeros_like(x_lg)

View File

@ -83,7 +83,7 @@ class Discriminator_VGG_128(nn.Module):
class Discriminator_VGG_128_GN(nn.Module):
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False):
def __init__(self, in_nc, nf, input_img_factor=1, do_checkpointing=False, extra_conv=False):
super(Discriminator_VGG_128_GN, self).__init__()
self.do_checkpointing = do_checkpointing
@ -111,6 +111,14 @@ class Discriminator_VGG_128_GN(nn.Module):
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
self.extra_conv = extra_conv
if extra_conv:
self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
self.bn5_0 = nn.GroupNorm(8, nf * 8, affine=True)
self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
self.bn5_1 = nn.GroupNorm(8, nf * 8, affine=True)
input_img_factor = input_img_factor // 2
final_nf = nf * 8
# activation function
@ -136,6 +144,10 @@ class Discriminator_VGG_128_GN(nn.Module):
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
if self.extra_conv:
fea = self.lrelu(self.bn5_0(self.conv5_0(fea)))
fea = self.lrelu(self.bn5_1(self.conv5_1(fea)))
return fea
def forward(self, x):

View File

@ -196,7 +196,9 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
elif which_model == 'discriminator_vgg_128_gn':
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
input_img_factor=img_sz / 128, extra_conv=extra_conv)
if wrap:
netD = GradDiscWrapper(netD)
elif which_model == 'discriminator_vgg_128_gn_checkpointed':

View File

@ -56,6 +56,8 @@ def create_injector(opt_inject, env):
return BatchRotateInjector(opt_inject, env)
elif type == 'sr_diffs':
return SrDiffsInjector(opt_inject, env)
elif type == 'multiframe_combiner':
return MultiFrameCombiner(opt_inject, env)
else:
raise NotImplementedError
@ -419,3 +421,52 @@ class SrDiffsInjector(Injector):
elif self.mode == 'recombine':
combined = resampled_lq + hq
return {self.output: combined}
class MultiFrameCombiner(Injector):
    """Injector that turns a stack of video frames into a single network input.

    Two modes are supported, selected by ``opt['mode']``:

    * ``"combine"``: reads a 5-D LQ tensor (unpacked as ``b, f, c, h, w``) and
      a 5-D HQ tensor from the state. Every non-center LQ frame is run through
      the flow network together with the center frame and then passed through
      ``Resample2d`` (presumably warping it onto the center frame -- confirm
      against the flownet2 package). The center frame followed by the
      resampled frames is concatenated along dim 1.
    * ``"synthesize"``: repeats a single LQ image ``opt['dim']`` times along
      dim 1.

    Options read from ``opt``:
        mode:   ``"combine"`` or ``"synthesize"``.
        dim:    optional; repeat count used by ``"synthesize"``.
        flow:   key into ``env['generators']`` selecting the flow network.
        in:     state key of the LQ input.
        in_hq:  state key of the HQ input.
        out:    state key under which the LQ result is stored.
        out_hq: state key under which the HQ center frame is stored.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.mode = opt['mode']
        # `dim` is only consumed by synthesize(); it stays None when omitted.
        self.dim = opt['dim'] if 'dim' in opt else None
        self.flow = opt['flow']
        self.in_lq_key = opt['in']
        self.in_hq_key = opt['in_hq']
        self.out_lq_key = opt['out']
        self.out_hq_key = opt['out_hq']
        # Imported here rather than at module level, as in the original code,
        # so the resample2d package is only loaded when this injector is built.
        from models.flownet2.networks.resample2d_package.resample2d import Resample2d
        self.resampler = Resample2d()

    def combine(self, state):
        """Stack the center frame with the resampled neighbouring frames.

        Returns a dict with three entries:
            ``out``:                 frames concatenated along dim 1, center
                                     frame first, then the others in index
                                     order.
            ``out_hq``:              the center frame of the HQ input.
            ``out + "_flow_sample"``: the same frames concatenated along
                                     dim 0.
        """
        flow = self.env['generators'][self.flow]
        lq = state[self.in_lq_key]
        hq = state[self.in_hq_key]
        _, num_frames, _, _, _ = lq.shape
        center = num_frames // 2
        center_img = lq[:, center, :, :, :]
        imgs = [center_img]
        # Flow estimation and resampling run without gradient tracking.
        with torch.no_grad():
            for i in range(num_frames):
                if i == center:
                    continue
                nimg = lq[:, i, :, :, :]
                # The flow network receives the (center, neighbour) pair
                # stacked on a new dim 2 and cast to float32.
                flowfield = flow(torch.stack([center_img, nimg], dim=2).float())
                nimg = self.resampler(nimg, flowfield)
                imgs.append(nimg)
        hq_out = hq[:, center, :, :, :]
        return {self.out_lq_key: torch.cat(imgs, dim=1),
                self.out_hq_key: hq_out,
                self.out_lq_key + "_flow_sample": torch.cat(imgs, dim=0)}

    def synthesize(self, state):
        """Repeat the LQ input ``dim`` times along dim 1.

        Raises:
            ValueError: if the ``dim`` option was not supplied. Without this
                check the missing value only surfaces as an opaque TypeError
                from ``Tensor.repeat``.
        """
        if self.dim is None:
            raise ValueError(
                "MultiFrameCombiner in 'synthesize' mode requires the 'dim' option.")
        lq = state[self.in_lq_key]
        return {
            self.out_lq_key: lq.repeat(1, self.dim, 1, 1)
        }

    def forward(self, state):
        """Dispatch to ``synthesize`` or ``combine`` according to ``mode``.

        Raises:
            NotImplementedError: for any other mode value.
        """
        if self.mode == "synthesize":
            return self.synthesize(state)
        elif self.mode == "combine":
            return self.combine(state)
        else:
            raise NotImplementedError

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_exd_imgsetext_rrdb4x_6bl_2stride/train_exd_imgsetext_rrdb4x_6bl_2stride.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb4x_2stride_multiframe.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()