diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 68fca1e6..d8485273 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -235,7 +235,7 @@ class ExtensibleTrainer(BaseModel): if v not in state.keys(): continue # This can happen for several reasons (ex: 'after' defs), just ignore it. for i, dbgv in enumerate(state[v]): - if 'recurrent_visual_indices' in self.opt['logger'].keys(): + if 'recurrent_visual_indices' in self.opt['logger'].keys() and len(dbgv.shape)==5: for rvi in self.opt['logger']['recurrent_visual_indices']: rdbgv = dbgv[:, rvi] if rdbgv.shape[1] > 3: diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 8b763b09..42a04770 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -152,7 +152,8 @@ class RRDBNet(nn.Module): self.in_channels = in_channels first_conv_stride = 1 if in_channels <= 4 else scale first_conv_ksize = 3 if first_conv_stride == 1 else 7 - self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, 1) + first_conv_padding = 1 if first_conv_stride == 1 else 3 + self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding) self.body = make_layer( body_block, num_blocks, @@ -186,7 +187,7 @@ class RRDBNet(nn.Module): x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic") if ref is None: ref = torch.zeros_like(x_lg) - x_lg = torch.cat([x_lg, ref]) + x_lg = torch.cat([x_lg, ref], dim=1) feat = self.conv_first(x_lg) body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat)) feat = feat + body_feat diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 7a09cfc1..2da702ca 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -138,7 +138,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): final_results = {} # Include 'hq_batched' here - because why not... Don't really need a separate injector for this. b, s, c, h, w = state['hq'].shape - final_results['hq_batched'] = state['hq'].view(b*s, c, h, w) + final_results['hq_batched'] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w) for k, v in results.items(): final_results[k] = torch.stack(v, dim=1) final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ. @@ -345,8 +345,10 @@ class PingPongLoss(ConfigurableLoss): img_count = fake.shape[1] for i in range((img_count - 1) // 2): early = fake[:, i] - late = fake[:, -i] + late = fake[:, -(i+1)] l_total += self.criterion(early, late) + if self.env['step'] % 50 == 0: + self.produce_teco_visual_debugs2(early, late, i) if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(fake) @@ -363,3 +365,11 @@ class PingPongLoss(ConfigurableLoss): img = imglist[:, i] torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, ))) + def produce_teco_visual_debugs2(self, imga, imgb, i): + if self.env['rank'] > 0: + return + base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) + os.makedirs(base_path, exist_ok=True) + torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i, ))) + torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i, ))) +