Mods to tecogan to allow use of embeddings as input

This commit is contained in:
James Betker 2020-11-24 09:24:02 -07:00
parent b10bcf6436
commit f6098155cd
6 changed files with 127 additions and 53 deletions

View File

@ -6,7 +6,7 @@ import torch.nn.functional as F
import torchvision
from torch.utils.checkpoint import checkpoint_sequential
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu
from models.archs.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu
class ResidualDenseBlock(nn.Module):
@ -60,11 +60,16 @@ class RRDB(nn.Module):
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels, growth_channels=32):
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
if reduce_to is not None:
self.reducer = ConvGnLelu(mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True)
self.recover_ch = mid_channels - reduce_to
else:
self.reducer = None
def forward(self, x):
"""Forward function.
@ -78,6 +83,10 @@ class RRDB(nn.Module):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
if self.reducer is not None:
out = self.reducer(out)
b, f, h, w = out.shape
out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1)
# Empirically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x
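Note: the reduce_to path projects the block output down to reduce_to channels and then zero-pads it back to mid_channels, so the scaled residual sum stays shape-compatible with x. A minimal standalone sketch of that trick (toy shapes, not the repo's code):

import torch

def reduce_pad_residual(out, x, recover_ch):
    # out: (b, reduce_to, h, w) after the reducer; x: (b, mid_channels, h, w)
    b, _, h, w = out.shape
    pad = torch.zeros((b, recover_ch, h, w), device=out.device, dtype=out.dtype)
    out = torch.cat([out, pad], dim=1)   # back to mid_channels
    return out * 0.2 + x                 # residual sum is shape-compatible again

x = torch.randn(2, 64, 16, 16)           # mid_channels=64
reduced = torch.randn(2, 48, 16, 16)     # reduce_to=48 -> recover_ch=16
assert reduce_pad_residual(reduced, x, 16).shape == x.shape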
@ -92,12 +101,19 @@ class RRDBWithBypass(nn.Module):
growth_channels (int): Channels for each growth.
"""
def __init__(self, mid_channels, growth_channels=32):
def __init__(self, mid_channels, growth_channels=32, reduce_to=None):
super(RRDBWithBypass, self).__init__()
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
if reduce_to is not None:
self.reducer = ConvGnLelu(mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True)
self.recover_ch = mid_channels - reduce_to
bypass_channels = mid_channels + reduce_to
else:
self.reducer = None
bypass_channels = mid_channels * 2
self.bypass = nn.Sequential(ConvGnSilu(bypass_channels, mid_channels, kernel_size=3, bias=True, activation=True, norm=True),
ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False),
ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False),
nn.Sigmoid())
@ -114,8 +130,15 @@ class RRDBWithBypass(nn.Module):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
if self.reducer is not None:
out = self.reducer(out)
# Compute the bypass gate on the reduced features (the bypass conv was built for
# mid_channels + reduce_to inputs), then zero-pad back to mid_channels for the residual sum.
bypass = self.bypass(torch.cat([x, out], dim=1))
b, f, h, w = out.shape
out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1)
else:
bypass = self.bypass(torch.cat([x, out], dim=1))
self.bypass_map = bypass.detach().clone()
# Empirically, we use 0.2 to scale the residual for better performance
return out * 0.2 * bypass + x
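Note: the bypass branch still emits a single-channel sigmoid map, so it broadcasts over the zero-padded residual regardless of reduce_to. A toy sketch of the gating arithmetic (assumed shapes):

import torch

b, mid, reduce_to, h, w = 2, 64, 48, 16, 16
x = torch.randn(b, mid, h, w)
out = torch.randn(b, reduce_to, h, w)    # reduced block output
gate = torch.rand(b, 1, h, w)            # sigmoid bypass map
out = torch.cat([out, torch.zeros(b, mid - reduce_to, h, w)], dim=1)
y = out * 0.2 * gate + x                 # the gate broadcasts over channels
assert y.shape == x.shape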
@ -143,30 +166,45 @@ class RRDBNet(nn.Module):
num_blocks=23,
growth_channels=32,
body_block=RRDB,
blocks_per_checkpoint=4,
blocks_per_checkpoint=1,
scale=4,
additive_mode="not_additive" # Options: "not", "additive", "additive_enforced"
additive_mode="not", # Options: "not", "additive", "additive_enforced"
headless=False,
feature_channels=64, # Only applicable when headless=True. How many channels are used at the trunk level.
output_mode="hq_only", # Options: "hq_only", "hq+features", "features_only"
):
super(RRDBNet, self).__init__()
assert output_mode in ['hq_only', 'hq+features', 'features_only']
assert additive_mode in ['not', 'additive', 'additive_enforced']
self.num_blocks = num_blocks
self.blocks_per_checkpoint = blocks_per_checkpoint
self.scale = scale
self.in_channels = in_channels
self.output_mode = output_mode
first_conv_stride = 1 if in_channels <= 4 else scale
first_conv_ksize = 3 if first_conv_stride == 1 else 7
first_conv_padding = 1 if first_conv_stride == 1 else 3
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding)
if headless:
self.conv_first = None
self.reduce_ch = feature_channels
reduce_to = feature_channels
self.conv_ref_first = ConvGnLelu(3, feature_channels, 7, stride=2, norm=False, activation=False, bias=True)
else:
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding)
self.reduce_ch = mid_channels
reduce_to = None
self.body = make_layer(
body_block,
num_blocks,
mid_channels=mid_channels,
growth_channels=growth_channels)
self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
growth_channels=growth_channels,
reduce_to=reduce_to)
self.conv_body = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_up2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
self.conv_last = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1)
self.additive_mode = additive_mode
if additive_mode == "additive_enforced":
@ -178,7 +216,8 @@ class RRDBNet(nn.Module):
self.conv_first, self.conv_body, self.conv_up1,
self.conv_up2, self.conv_hr, self.conv_last
]:
default_init_weights(m, 0.1)
if m is not None:
default_init_weights(m, 0.1)
def forward(self, x, ref=None):
"""Forward function.
@ -189,25 +228,39 @@ class RRDBNet(nn.Module):
Returns:
Tensor: Forward results.
"""
if self.in_channels > 4:
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
if ref is None:
ref = torch.zeros_like(x_lg)
x_lg = torch.cat([x_lg, ref], dim=1)
if self.conv_first is None:
# Headless mode -> embedding inputs.
if ref is not None:
ref = self.conv_ref_first(ref)
feat = torch.cat([x, ref], dim=1)
else:
feat = x
else:
x_lg = x
feat = self.conv_first(x_lg)
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
# "Normal" mode -> image input.
if self.in_channels > 4:
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
if ref is None:
ref = torch.zeros_like(x_lg)
x_lg = torch.cat([x_lg, ref], dim=1)
else:
x_lg = x
feat = self.conv_first(x_lg)
feat = checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)
feat = feat[:, :self.reduce_ch]
body_feat = self.conv_body(feat)
feat = feat + body_feat
if self.output_mode == "features_only":
return feat
# upsample
feat = self.lrelu(
out = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
if self.scale == 4:
feat = self.lrelu(
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
out = self.lrelu(
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
else:
feat = self.lrelu(self.conv_up2(feat))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
out = self.lrelu(self.conv_up2(out))
out = self.conv_last(self.lrelu(self.conv_hr(out)))
if "additive" in self.additive_mode:
x_interp = F.interpolate(x, scale_factor=self.scale, mode='bilinear')
if self.additive_mode == 'additive':
@ -216,9 +269,14 @@ class RRDBNet(nn.Module):
out_pooled = self.add_enforced_pool(out)
out = out - F.interpolate(out_pooled, scale_factor=self.scale, mode='nearest')
out = out + x_interp
if self.output_mode == "hq+features":
return out, feat
return out
def visual_dbg(self, step, path):
for i, bm in enumerate(self.body):
if hasattr(bm, 'bypass_map'):
torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1)))
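Note: with these changes the trunk can run headless, taking a trunk-level embedding directly and optionally fusing in a stride-2 embedding of a reference image. A hedged usage sketch (hypothetical sizes; assumes mid_channels == 2 * feature_channels so the concatenation below lines up):

import torch
from models.archs.RRDBNet_arch import RRDBNet  # the class modified above

net = RRDBNet(in_channels=3, out_channels=3, mid_channels=128, num_blocks=4,
              headless=True, feature_channels=64, scale=4, output_mode="hq+features")
emb = torch.randn(1, 64, 32, 32)   # embedding input; conv_first is skipped
ref = torch.randn(1, 3, 64, 64)    # reference image; conv_ref_first strides it to 32x32
hq, feat = net(emb, ref)           # cat([emb, embedded ref]) -> 128 == mid_channels

With output_mode="features_only" the forward returns the trunk features before upsampling, which is presumably what the tecogan recurrence consumes.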

View File

@ -39,14 +39,17 @@ def define_G(opt, opt_net, scale=None):
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'RRDBNet':
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode)
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode,
output_mode=output_mode)
elif which_model == 'RRDBNetBypass':
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
additive_mode=additive_mode)
additive_mode=additive_mode, output_mode=output_mode)
elif which_model == 'rcan':
#args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
opt_net['rgb_range'] = 255
@ -110,8 +113,6 @@ def define_G(opt, opt_net, scale=None):
netG = SwitchedGen_arch.BackboneResnet()
elif which_model == "tecogen":
netG = TecoGen(opt_net['nf'], opt_net['scale'])
elif which_model == "basic_resampling_flow_predictor":
netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale'])
elif which_model == "rrdb_with_latent":
netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
@ -153,6 +154,11 @@ def define_G(opt, opt_net, scale=None):
netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'], pretrain_rrdb_path=opt_net['pretrain_path'])
elif which_model == 'rrdb_centipede':
output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only'
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], scale=opt_net['scale'],
headless=True, output_mode=output_mode)
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG
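Note: a hedged sketch of the opt_net block that would reach the new 'rrdb_centipede' branch (key names taken from the code above; values and the dispatch key are hypothetical):

opt_net = {
    'which_model_G': 'rrdb_centipede',   # assumed dispatch key read as which_model
    'in_nc': 3, 'out_nc': 3,
    'nf': 128,                           # mid_channels
    'nb': 23,                            # num_blocks
    'scale': 4,
    'output_mode': 'features_only',      # optional; defaults to 'hq_only'
}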

View File

@ -285,6 +285,7 @@ class ForEachInjector(Injector):
o['in'] = '_in'
o['out'] = '_out'
self.injector = create_injector(o, self.env)
self.aslist = opt['aslist'] if 'aslist' in opt.keys() else False
def forward(self, state):
injs = []
@ -293,7 +294,10 @@ class ForEachInjector(Injector):
for i in range(inputs.shape[1]):
st['_in'] = inputs[:, i]
injs.append(self.injector(st)['_out'])
return {self.output: torch.stack(injs, dim=1)}
if self.aslist:
return {self.output: injs}
else:
return {self.output: torch.stack(injs, dim=1)}
class ConstantInjector(Injector):
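Note: aslist switches the ForEachInjector's output from a dim-1 stack to a plain Python list of per-index results. Behaviorally (toy tensors):

import torch

injs = [torch.randn(4, 3, 16, 16) for _ in range(5)]
stacked = torch.stack(injs, dim=1)   # aslist=False -> one (4, 5, 3, 16, 16) tensor
as_list = injs                       # aslist=True  -> list of five (4, 3, 16, 16) tensors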

View File

@ -140,7 +140,7 @@ class ConfigurableStep(Module):
# Don't do injections tagged with 'after' or 'before' when we are out of spec.
if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
'before' in inj.opt.keys() and self.env['step'] > inj.opt['before'] or \
'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
'every' in inj.opt.keys() and self.env['step'] % inj.opt['every'] != 0:
continue
injected = inj(local_state)
local_state.update(injected)
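Note: the condition above gates injectors by step. For example, an injector configured with the keys below (values hypothetical) runs only between steps 1000 and 5000, and only on every 4th step:

inj_opt = {'after': 1000, 'before': 5000, 'every': 4}
step = 1004
skipped = ('after' in inj_opt and step < inj_opt['after']) or \
          ('before' in inj_opt and step > inj_opt['before']) or \
          ('every' in inj_opt and step % inj_opt['every'] != 0)
assert not skipped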

View File

@ -44,10 +44,12 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
self.flow = opt['flow_network']
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
self.recurrent_index = opt['recurrent_index']
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
self.output_recurrent_index = opt['output_recurrent_index'] if 'output_recurrent_index' in opt.keys() else self.output_hq_index
self.scale = opt['scale']
self.resample = Resample2d()
self.flow_key = opt['flow_input_key'] if 'flow_input_key' in opt.keys() else None
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Specifies inputs used on the first teco iteration; subsequent iterations use 'in'.
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index.
@ -82,20 +84,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
else:
input = extract_inputs_index(inputs, i)
with torch.no_grad() and autocast(enabled=False):
# This is a hack to work around the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
# made here that if you are operating at 4x scale, your inputs are 32px x 32px.
if self.scale >= 4:
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
if self.flow_key is not None:
flow_input = state[self.flow_key][:, i]
else:
flow_input = input[self.input_lq_index]
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
reduced_recurrent = F.interpolate(hq_recurrent, scale_factor=1/self.scale, mode='bicubic')
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
flowfield = flow(flow_input)
if recurrent_input.shape[-1] != flow_input.shape[-1]:
flowfield = F.interpolate(flowfield, scale_factor=self.scale, mode='bicubic')
recurrent_input = self.resample(recurrent_input.float(), flowfield)
input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
if input[self.input_lq_index].shape[1] == 3: # Only debug this if we're dealing with images.
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*input)
@ -104,7 +107,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
gen_out = [gen_out]
for i, out_key in enumerate(self.output):
results[out_key].append(gen_out[i])
recurrent_input = gen_out[self.output_hq_index]
hq_recurrent = gen_out[self.output_hq_index]
recurrent_input = gen_out[self.output_recurrent_index]
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
if self.do_backwards:
@ -113,20 +117,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
input = extract_inputs_index(inputs, i)
with torch.no_grad():
with autocast(enabled=False):
# This is a hack to work around the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
# made here that if you are operating at 4x scale, your inputs are 32px x 32px.
if self.scale >= 4:
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
if self.flow_key is not None:
flow_input = state[self.flow_key][:, i]
else:
flow_input = input[self.input_lq_index]
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
reduced_recurrent = F.interpolate(hq_recurrent, scale_factor=1/self.scale, mode='bicubic')
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
flowfield = flow(flow_input)
if recurrent_input.shape[-1] != flow_input.shape[-1]:
flowfield = F.interpolate(flowfield, scale_factor=self.scale, mode='bicubic')
recurrent_input = self.resample(recurrent_input.float(), flowfield)
input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
if input[self.input_lq_index].shape[1] == 3: # Only debug this if we're dealing with images.
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*input)
@ -135,7 +140,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
gen_out = [gen_out]
for i, out_key in enumerate(self.output):
results[out_key].append(gen_out[i])
recurrent_input = gen_out[self.output_hq_index]
hq_recurrent = gen_out[self.output_hq_index]
recurrent_input = gen_out[self.output_recurrent_index]
final_results = {}
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
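Note: the recurrence now tracks two outputs: gen_out[output_hq_index] (image-like, downsampled to build the flow input) and gen_out[output_recurrent_index] (the tensor that is actually warped and fed back; defaults to the same index). One iteration's flow bookkeeping, sketched with stand-in tensors for flownet2 and Resample2d:

import torch
import torch.nn.functional as F

scale = 4
lq = torch.randn(1, 3, 32, 32)                   # input[input_lq_index]
hq_recurrent = torch.randn(1, 3, 128, 128)       # gen_out[output_hq_index]
recurrent = torch.randn(1, 64, 128, 128)         # gen_out[output_recurrent_index]

reduced = F.interpolate(hq_recurrent, scale_factor=1/scale, mode='bicubic')
flow_input = torch.stack([lq, reduced], dim=2).float()   # (b, c, 2, h, w) frame pair
flowfield = torch.randn(1, 2, 32, 32)                    # stand-in for flow(flow_input)
if recurrent.shape[-1] != flow_input.shape[-1]:          # recurrent lives at HQ resolution
    flowfield = F.interpolate(flowfield, scale_factor=scale, mode='bicubic')
# recurrent = resample(recurrent, flowfield)             # Resample2d warp, then fed back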

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_v2.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()