Tecogan & other fixes

This commit is contained in:
James Betker 2020-10-30 00:19:58 -06:00
parent b316078a15
commit a3918fa808
3 changed files with 16 additions and 5 deletions

View File

@ -235,7 +235,7 @@ class ExtensibleTrainer(BaseModel):
if v not in state.keys():
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
for i, dbgv in enumerate(state[v]):
if 'recurrent_visual_indices' in self.opt['logger'].keys():
if 'recurrent_visual_indices' in self.opt['logger'].keys() and len(dbgv.shape)==5:
for rvi in self.opt['logger']['recurrent_visual_indices']:
rdbgv = dbgv[:, rvi]
if rdbgv.shape[1] > 3:

View File

@ -152,7 +152,8 @@ class RRDBNet(nn.Module):
self.in_channels = in_channels
first_conv_stride = 1 if in_channels <= 4 else scale
first_conv_ksize = 3 if first_conv_stride == 1 else 7
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, 1)
first_conv_padding = 1 if first_conv_stride == 1 else 3
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding)
self.body = make_layer(
body_block,
num_blocks,
@ -186,7 +187,7 @@ class RRDBNet(nn.Module):
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
if ref is None:
ref = torch.zeros_like(x_lg)
x_lg = torch.cat([x_lg, ref])
x_lg = torch.cat([x_lg, ref], dim=1)
feat = self.conv_first(x_lg)
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
feat = feat + body_feat

View File

@ -138,7 +138,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
final_results = {}
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
b, s, c, h, w = state['hq'].shape
final_results['hq_batched'] = state['hq'].view(b*s, c, h, w)
final_results['hq_batched'] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w)
for k, v in results.items():
final_results[k] = torch.stack(v, dim=1)
final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.
@ -345,8 +345,10 @@ class PingPongLoss(ConfigurableLoss):
img_count = fake.shape[1]
for i in range((img_count - 1) // 2):
early = fake[:, i]
late = fake[:, -i]
late = fake[:, -(i+1)]
l_total += self.criterion(early, late)
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs2(early, late, i)
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake)
@ -363,3 +365,11 @@ class PingPongLoss(ConfigurableLoss):
img = imglist[:, i]
torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))
def produce_teco_visual_debugs2(self, imga, imgb, i):
if self.env['rank'] > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i, )))
torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i, )))