Tecogan & other fixes
This commit is contained in:
parent
b316078a15
commit
a3918fa808
|
@ -235,7 +235,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
if v not in state.keys():
|
if v not in state.keys():
|
||||||
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
|
continue # This can happen for several reasons (ex: 'after' defs), just ignore it.
|
||||||
for i, dbgv in enumerate(state[v]):
|
for i, dbgv in enumerate(state[v]):
|
||||||
if 'recurrent_visual_indices' in self.opt['logger'].keys():
|
if 'recurrent_visual_indices' in self.opt['logger'].keys() and len(dbgv.shape)==5:
|
||||||
for rvi in self.opt['logger']['recurrent_visual_indices']:
|
for rvi in self.opt['logger']['recurrent_visual_indices']:
|
||||||
rdbgv = dbgv[:, rvi]
|
rdbgv = dbgv[:, rvi]
|
||||||
if rdbgv.shape[1] > 3:
|
if rdbgv.shape[1] > 3:
|
||||||
|
|
|
@ -152,7 +152,8 @@ class RRDBNet(nn.Module):
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
first_conv_stride = 1 if in_channels <= 4 else scale
|
first_conv_stride = 1 if in_channels <= 4 else scale
|
||||||
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
first_conv_ksize = 3 if first_conv_stride == 1 else 7
|
||||||
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, 1)
|
first_conv_padding = 1 if first_conv_stride == 1 else 3
|
||||||
|
self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding)
|
||||||
self.body = make_layer(
|
self.body = make_layer(
|
||||||
body_block,
|
body_block,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
|
@ -186,7 +187,7 @@ class RRDBNet(nn.Module):
|
||||||
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
|
x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic")
|
||||||
if ref is None:
|
if ref is None:
|
||||||
ref = torch.zeros_like(x_lg)
|
ref = torch.zeros_like(x_lg)
|
||||||
x_lg = torch.cat([x_lg, ref])
|
x_lg = torch.cat([x_lg, ref], dim=1)
|
||||||
feat = self.conv_first(x_lg)
|
feat = self.conv_first(x_lg)
|
||||||
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
|
body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
|
||||||
feat = feat + body_feat
|
feat = feat + body_feat
|
||||||
|
|
|
@ -138,7 +138,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
final_results = {}
|
final_results = {}
|
||||||
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
|
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
|
||||||
b, s, c, h, w = state['hq'].shape
|
b, s, c, h, w = state['hq'].shape
|
||||||
final_results['hq_batched'] = state['hq'].view(b*s, c, h, w)
|
final_results['hq_batched'] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w)
|
||||||
for k, v in results.items():
|
for k, v in results.items():
|
||||||
final_results[k] = torch.stack(v, dim=1)
|
final_results[k] = torch.stack(v, dim=1)
|
||||||
final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.
|
final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.
|
||||||
|
@ -345,8 +345,10 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
img_count = fake.shape[1]
|
img_count = fake.shape[1]
|
||||||
for i in range((img_count - 1) // 2):
|
for i in range((img_count - 1) // 2):
|
||||||
early = fake[:, i]
|
early = fake[:, i]
|
||||||
late = fake[:, -i]
|
late = fake[:, -(i+1)]
|
||||||
l_total += self.criterion(early, late)
|
l_total += self.criterion(early, late)
|
||||||
|
if self.env['step'] % 50 == 0:
|
||||||
|
self.produce_teco_visual_debugs2(early, late, i)
|
||||||
|
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(fake)
|
self.produce_teco_visual_debugs(fake)
|
||||||
|
@ -363,3 +365,11 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
img = imglist[:, i]
|
img = imglist[:, i]
|
||||||
torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))
|
torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))
|
||||||
|
|
||||||
|
def produce_teco_visual_debugs2(self, imga, imgb, i):
|
||||||
|
if self.env['rank'] > 0:
|
||||||
|
return
|
||||||
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||||
|
os.makedirs(base_path, exist_ok=True)
|
||||||
|
torchvision.utils.save_image(imga.float(), osp.join(base_path, "%s_a.png" % (i, )))
|
||||||
|
torchvision.utils.save_image(imgb.float(), osp.join(base_path, "%s_b.png" % (i, )))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user