From e7850299365c4811ec46c56593ed72e102dab009 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 10 Oct 2020 22:39:55 -0600 Subject: [PATCH] Mods needed to support SPSR archs with teco gan --- codes/models/ExtensibleTrainer.py | 16 ++++++++++---- codes/models/archs/SPSR_arch.py | 15 ++++++++++++-- codes/models/networks.py | 3 ++- codes/models/steps/injectors.py | 31 +++++++++++++++------------- codes/models/steps/losses.py | 4 ++-- codes/models/steps/tecogan_losses.py | 21 +++++++++++++------ codes/train2.py | 2 +- 7 files changed, 62 insertions(+), 30 deletions(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 51154fd0..314f0b07 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -234,10 +234,18 @@ class ExtensibleTrainer(BaseModel): continue # This can happen for several reasons (ex: 'after' defs), just ignore it. if step % self.opt['logger']['visual_debug_rate'] == 0: for i, dbgv in enumerate(state[v]): - if dbgv.shape[1] > 3: - dbgv = dbgv[:,:3,:,:] - os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) - utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i))) + if 'recurrent_visual_indices' in self.opt['logger'].keys(): + for rvi in self.opt['logger']['recurrent_visual_indices']: + rdbgv = dbgv[:, rvi] + if rdbgv.shape[1] > 3: + rdbgv = rdbgv[:, :3, :, :] + os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) + utils.save_image(rdbgv, os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i))) + else: + if dbgv.shape[1] > 3: + dbgv = dbgv[:,:3,:,:] + os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) + utils.save_image(dbgv, os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i))) def compute_fea_loss(self, real, fake): with torch.no_grad(): diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 26240514..b822f1a4 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -460,13 +460,19 @@ class Spsr6(nn.Module): # Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding. class Spsr7(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, init_temperature=10): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, recurrent=False, init_temperature=10): super(Spsr7, self).__init__() n_upscale = int(math.log(upscale, 2)) # processing the input embedding self.reference_embedding = ReferenceImageBranch(nf) + self.recurrent = recurrent + if recurrent: + self.model_recurrent_conv = ConvGnLelu(3, nf, kernel_size=3, stride=2, norm=False, activation=False, + bias=True) + self.model_fea_recurrent_combine = ConvGnLelu(nf * 2, nf, 1, activation=False, norm=False, bias=False, weight_init_factor=.01) + # switch options self.nf = nf transformation_filters = nf @@ -522,7 +528,7 @@ class Spsr7(nn.Module): self.final_temperature_step = 10000 self.lr = None - def forward(self, x, ref, ref_center, update_attention_norm=True): + def forward(self, x, ref, ref_center, update_attention_norm=True, recurrent=None): # The attention_maps debugger outputs . Save that here. self.lr = x.detach().cpu() @@ -531,6 +537,11 @@ class Spsr7(nn.Module): ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) x = self.model_fea_conv(x) + if self.recurrent: + rec = self.model_recurrent_conv(recurrent) + br = self.model_fea_recurrent_combine(torch.cat([x, rec], dim=1)) + x = x + br + x1 = x x1, a1 = self.sw1(x1, True, identity=x, att_in=(x1, ref_embedding)) diff --git a/codes/models/networks.py b/codes/models/networks.py index f354f582..6e34ac6b 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -71,10 +71,11 @@ def define_G(opt, net_key='network_G', scale=None): multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3, init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == "spsr7": + recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3, - init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) + init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent) elif which_model == "spsr9": xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = spsr.Spsr9(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index f0d26185..9ef6fb1e 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -34,8 +34,8 @@ def create_injector(opt_inject, env): return ConcatenateInjector(opt_inject, env) elif type == 'margin_removal': return MarginRemoval(opt_inject, env) - elif type == 'constant': - return ConstantInjector(opt_inject, env) + elif type == 'foreach': + return ForEachInjector(opt_inject, env) else: raise NotImplementedError @@ -221,18 +221,21 @@ class MarginRemoval(Injector): return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} -class ConstantInjector(Injector): +# Produces an injection which is composed of applying a single injector multiple times across a single dimension. +class ForEachInjector(Injector): def __init__(self, opt, env): - super(ConstantInjector, self).__init__(opt, env) - self.constant_type = opt['constant_type'] - self.dim = opt['dim'] - self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use. + super(ForEachInjector, self).__init__(opt, env) + o = opt.copy() + o['type'] = opt['subtype'] + o['in'] = '_in' + o['out'] = '_out' + self.injector = create_injector(o, self.env) def forward(self, state): - bs = state[self.like].shape[0] - dev = state[self.like].device - if self.constant_type == 'zeros': - out = torch.zeros((bs,) + tuple(self.dim), device=dev) - else: - raise NotImplementedError - return { self.opt['out']: out } + injs = [] + st = state.copy() + inputs = state[self.opt['in']] + for i in range(inputs.shape[1]): + st['_in'] = inputs[:, i] + injs.append(self.injector(st)['_out']) + return {self.output: torch.stack(injs, dim=1)} diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index e2f60783..3e6eeb6f 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -394,7 +394,7 @@ class RecurrentLoss(ConfigurableLoss): real = state[self.opt['real']] for i in range(real.shape[1]): st['_real'] = real[:, i] - st['_fake'] = state[self.opt['fake']][i] + st['_fake'] = state[self.opt['fake']][:, i] total_loss += self.loss(net, st) return total_loss @@ -413,5 +413,5 @@ class ForElementLoss(ConfigurableLoss): def forward(self, net, state): st = state.copy() st['_real'] = state[self.opt['real']][:, self.index] - st['_fake'] = state[self.opt['fake']][self.index] + st['_fake'] = state[self.opt['fake']][:, self.index] return self.loss(net, st) diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 7f50f3e9..e2717d94 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env) self.flow = opt['flow_network'] self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 - self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 + self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0 self.recurrent_index = opt['recurrent_index'] self.scale = opt['scale'] self.resample = Resample2d() @@ -71,12 +71,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector): def forward(self, state): gen = self.env['generators'][self.opt['generator']] flow = self.env['generators'][self.flow] - results = [] first_inputs = extract_params_from_state(self.first_inputs, state) inputs = extract_params_from_state(self.input, state) if not isinstance(inputs, list): inputs = [inputs] + if not isinstance(self.output, list): + self.output = [self.output] + results = {} + for out_key in self.output: + results[out_key] = [] + # Go forward in the sequence first. first_step = True b, f, c, h, w = inputs[self.input_lq_index].shape @@ -101,8 +106,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector): gen_out = gen(*input) if isinstance(gen_out, torch.Tensor): gen_out = [gen_out] + for i, out_key in enumerate(self.output): + results[out_key].append(gen_out[i]) recurrent_input = gen_out[self.output_hq_index] - results.append(recurrent_input) # Now go backwards, skipping the last element (it's already stored in recurrent_input) if self.do_backwards: @@ -122,10 +128,13 @@ class RecurrentImageGeneratorSequenceInjector(Injector): gen_out = gen(*input) if isinstance(gen_out, torch.Tensor): gen_out = [gen_out] + for i, out_key in enumerate(self.output): + results[out_key].append(gen_out[i]) recurrent_input = gen_out[self.output_hq_index] - results.append(recurrent_input) - return {self.output: results} + for k, v in results.items(): + results[k] = torch.stack(v, dim=1) + return results def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it): if self.env['rank'] > 0: @@ -183,7 +192,7 @@ class TecoGanLoss(ConfigurableLoss): net = self.env['discriminators'][self.opt['discriminator']] flow_gen = self.env['generators'][self.image_flow_generator] real = state[self.opt['real']] - fake = torch.stack(state[self.opt['fake']], dim=1) + fake = state[self.opt['fake']] sequence_len = real.shape[1] lr = state[self.opt['lr_inputs']] l_total = 0 diff --git a/codes/train2.py b/codes/train2.py index 404bdfd4..5aa05dc8 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_pretrain_ssgteco.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_spsr7.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()