From 331c40f0c8a47d1a63a313078b70451cb3517f81 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 19 Oct 2020 15:23:04 -0600
Subject: [PATCH 1/2] Allow starting step to be forced

Useful for testing purposes or to force a validation.
---
 codes/train.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/codes/train.py b/codes/train.py
index efc6e378..4da72be3 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -170,6 +170,8 @@ def main():
     else:
         current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
         start_epoch = 0
+    if 'force_start_step' in opt.keys():
+        current_step = opt['force_start_step']
 
     #### training
     logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))

From 1b1ca297f868f70545926c62799d0e2b109477ed Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 19 Oct 2020 15:25:12 -0600
Subject: [PATCH 2/2] Fix recurrent=None bug in ChainedEmbeddingGen

---
 codes/models/archs/ChainedEmbeddingGen.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py
index d82f00d7..4e454a27 100644
--- a/codes/models/archs/ChainedEmbeddingGen.py
+++ b/codes/models/archs/ChainedEmbeddingGen.py
@@ -71,6 +71,8 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
         self.recurrent = recurrent
         self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
         if recurrent:
+            self.recurrent_nf = recurrent_nf
+            self.recurrent_stride = recurrent_stride
             self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False)
             self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False)
         self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
@@ -86,7 +88,12 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
         fea = self.initial_conv(x)
         if self.recurrent:
             if recurrent is None:
-                recurrent = torch.zeros_like(fea)
+                if self.recurrent_nf == 3:
+                    recurrent = torch.zeros_like(x)
+                    if self.recurrent_stride != 1:
+                        recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest')
+                else:
+                    recurrent = torch.zeros_like(fea)
             rec = self.recurrent_process(recurrent)
             fea, recstd = self.recurrent_join(fea, rec)
             self.ref_join_std = recstd.item()
@@ -101,4 +108,4 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
         return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea
 
     def get_debug_values(self, step, net_name):
-        return { 'ref_join_std': self.ref_join_std }
\ No newline at end of file
+        return { 'ref_join_std': self.ref_join_std }