ChainedGen: Output debugging information on blocks

James Betker committed 2020-10-21 16:36:23 -06:00
parent b54de69153
commit 5753e77d67
3 changed files with 20 additions and 13 deletions

View File

@@ -149,6 +149,8 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
         self.grad_extract = ImageGradientNoPadding()
         self.upsample = FinalUpsampleBlock2x(64)
         self.ref_join_std = 0
+        self.block_residual_means = [0 for _ in range(depth)]
+        self.block_residual_stds = [0 for _ in range(depth)]
         self.bypass_maps = []

     def forward(self, x, recurrent=None):
@@ -172,6 +174,8 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
             residual, bypass_map = checkpoint(self.bypasses[i], residual, context)
             fea = fea + residual
             self.bypass_maps.append(bypass_map.detach())
+            self.block_residual_means[i] = residual.mean().item()
+            self.block_residual_stds[i] = residual.std().item()
             if i < 3:
                 structure_br = checkpoint(self.structure_joins[i], grad, fea)
                 grad = grad + checkpoint(self.structure_blocks[i], structure_br)
@@ -184,4 +188,9 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):

     def get_debug_values(self, step, net_name):
         biases = [b.bias.item() for b in self.bypasses]
-        return { 'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases) }
+        blk_stds, blk_means = {}, {}
+        for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
+            blk_stds['block_%i' % (i+1,)] = s
+            blk_means['block_%i' % (i+1,)] = m
+        return {'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases),
+                'blocks_std': blk_stds, 'blocks_mean': blk_means}
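Taken together, the three hunks above add a lightweight probe: the constructor allocates one slot per block, forward() overwrites each slot with the latest residual's mean and std (via .item(), so nothing stays attached to the autograd graph), and get_debug_values() packages the slots into dicts. Below is a minimal, self-contained sketch of the same pattern; ToyChainedGen is a hypothetical stand-in, not code from this repo.

import torch
import torch.nn as nn

class ToyChainedGen(nn.Module):
    # Hypothetical stand-in for StructuredChainedEmbeddingGenWithBypass;
    # only the residual-statistics bookkeeping is reproduced here.
    def __init__(self, depth=4, channels=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Conv2d(channels, channels, 3, padding=1) for _ in range(depth)])
        # One slot per block, overwritten on every forward pass, so the
        # debug report always reflects the most recent iteration.
        self.block_residual_means = [0 for _ in range(depth)]
        self.block_residual_stds = [0 for _ in range(depth)]

    def forward(self, x):
        fea = x
        for i, block in enumerate(self.blocks):
            residual = block(fea)
            fea = fea + residual
            # .item() yields a plain Python float, detached from autograd.
            self.block_residual_means[i] = residual.mean().item()
            self.block_residual_stds[i] = residual.std().item()
        return fea

    def get_debug_values(self, step, net_name):
        blk_stds, blk_means = {}, {}
        for i, (s, m) in enumerate(zip(self.block_residual_stds,
                                       self.block_residual_means)):
            blk_stds['block_%i' % (i + 1,)] = s
            blk_means['block_%i' % (i + 1,)] = m
        return {'blocks_std': blk_stds, 'blocks_mean': blk_means}

net = ToyChainedGen()
net(torch.randn(1, 8, 16, 16))
print(net.get_debug_values(step=0, net_name='toy'))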

View File

@@ -204,7 +204,7 @@ def main():
                 _t = time()

             #### log
-            if current_step % opt['logger']['print_freq'] == 0:
+            if current_step % opt['logger']['print_freq'] == 0 and rank <= 0:
                 logs = model.get_current_log(current_step)
                 message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
                 for v in model.get_current_learning_rate():
@@ -212,16 +212,15 @@ def main():
                 message += ')] '
                 for k, v in logs.items():
                     if 'histogram' in k:
-                        if rank <= 0:
-                            tb_logger.add_histogram(k, v, current_step)
+                        tb_logger.add_histogram(k, v, current_step)
+                    elif isinstance(v, dict):
+                        tb_logger.add_scalars(k, v, current_step)
                     else:
                         message += '{:s}: {:.4e} '.format(k, v)
                         # tensorboard logger
                         if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                            if rank <= 0:
                                 tb_logger.add_scalar(k, v, current_step)
-                if rank <= 0:
-                    logger.info(message)
+                logger.info(message)

             #### save models and training states
             if current_step % opt['logger']['save_checkpoint_freq'] == 0:
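The new elif isinstance(v, dict) branch is what lets the dict-valued entries from the first file (blocks_std, blocks_mean) reach TensorBoard: SummaryWriter.add_scalars draws one chart per tag with one curve per dict key, while plain floats still go through add_scalar. A short sketch of that dispatch, with a made-up log directory and values:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('/tmp/tb_demo')  # hypothetical log directory
step = 100
# Shaped like get_debug_values() output merged into the training log.
logs = {'blocks_mean': {'block_1': 0.12, 'block_2': 0.07, 'block_3': 0.03},
        'l_pix': 2.4e-3}
for k, v in logs.items():
    if isinstance(v, dict):
        # One chart tagged k, one curve per dict key.
        writer.add_scalars(k, v, step)
    else:
        writer.add_scalar(k, v, step)
writer.close()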

View File

@@ -204,7 +204,7 @@ def main():
                 _t = time()

             #### log
-            if current_step % opt['logger']['print_freq'] == 0:
+            if current_step % opt['logger']['print_freq'] == 0 and rank <= 0:
                 logs = model.get_current_log(current_step)
                 message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
                 for v in model.get_current_learning_rate():
@@ -212,16 +212,15 @@ def main():
                 message += ')] '
                 for k, v in logs.items():
                     if 'histogram' in k:
-                        if rank <= 0:
-                            tb_logger.add_histogram(k, v, current_step)
+                        tb_logger.add_histogram(k, v, current_step)
+                    elif isinstance(v, dict):
+                        tb_logger.add_scalars(k, v, current_step)
                     else:
                         message += '{:s}: {:.4e} '.format(k, v)
                         # tensorboard logger
                         if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                            if rank <= 0:
                                 tb_logger.add_scalar(k, v, current_step)
-                if rank <= 0:
-                    logger.info(message)
+                logger.info(message)

             #### save models and training states
             if current_step % opt['logger']['save_checkpoint_freq'] == 0:
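This file gets the same treatment as the previous one: hoisting rank <= 0 into the print_freq condition means non-master processes skip message formatting entirely instead of re-checking rank at every logging call. The comparison is <= rather than == because, by the usual convention in scripts like this (an assumption here, not shown in the diff), rank is -1 when training is not distributed. A minimal sketch of that gating:

import logging
import torch.distributed as dist

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('base')

def get_rank():
    # Assumed convention: -1 when torch.distributed is not in use, so
    # `rank <= 0` covers both single-process runs and rank 0 of a DDP job.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return -1

rank = get_rank()
print_freq = 100
for current_step in range(300):
    if current_step % print_freq == 0 and rank <= 0:
        logger.info('iter %d: only the master process logs', current_step)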