ChainedGen: Output debugging information on blocks
This commit is contained in:
parent
b54de69153
commit
5753e77d67
|
@ -149,6 +149,8 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
|
|||
self.grad_extract = ImageGradientNoPadding()
|
||||
self.upsample = FinalUpsampleBlock2x(64)
|
||||
self.ref_join_std = 0
|
||||
self.block_residual_means = [0 for _ in range(depth)]
|
||||
self.block_residual_stds = [0 for _ in range(depth)]
|
||||
self.bypass_maps = []
|
||||
|
||||
def forward(self, x, recurrent=None):
|
||||
|
@ -172,6 +174,8 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
|
|||
residual, bypass_map = checkpoint(self.bypasses[i], residual, context)
|
||||
fea = fea + residual
|
||||
self.bypass_maps.append(bypass_map.detach())
|
||||
self.block_residual_means[i] = residual.mean().item()
|
||||
self.block_residual_stds[i] = residual.std().item()
|
||||
if i < 3:
|
||||
structure_br = checkpoint(self.structure_joins[i], grad, fea)
|
||||
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
||||
|
@ -184,4 +188,9 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
|
|||
|
||||
def get_debug_values(self, step, net_name):
|
||||
biases = [b.bias.item() for b in self.bypasses]
|
||||
return { 'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases) }
|
||||
blk_stds, blk_means = {}, {}
|
||||
for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
|
||||
blk_stds['block_%i' % (i+1,)] = s
|
||||
blk_means['block_%i' % (i+1,)] = m
|
||||
return {'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases),
|
||||
'blocks_std': blk_stds, 'blocks_mean': blk_means}
|
||||
|
|
|
@ -204,7 +204,7 @@ def main():
|
|||
_t = time()
|
||||
|
||||
#### log
|
||||
if current_step % opt['logger']['print_freq'] == 0:
|
||||
if current_step % opt['logger']['print_freq'] == 0 and rank <= 0:
|
||||
logs = model.get_current_log(current_step)
|
||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||
for v in model.get_current_learning_rate():
|
||||
|
@ -212,16 +212,15 @@ def main():
|
|||
message += ')] '
|
||||
for k, v in logs.items():
|
||||
if 'histogram' in k:
|
||||
if rank <= 0:
|
||||
tb_logger.add_histogram(k, v, current_step)
|
||||
tb_logger.add_histogram(k, v, current_step)
|
||||
elif isinstance(v, dict):
|
||||
tb_logger.add_scalars(k, v, current_step)
|
||||
else:
|
||||
message += '{:s}: {:.4e} '.format(k, v)
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
if rank <= 0:
|
||||
tb_logger.add_scalar(k, v, current_step)
|
||||
if rank <= 0:
|
||||
logger.info(message)
|
||||
logger.info(message)
|
||||
|
||||
#### save models and training states
|
||||
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
|
|
|
@ -204,7 +204,7 @@ def main():
|
|||
_t = time()
|
||||
|
||||
#### log
|
||||
if current_step % opt['logger']['print_freq'] == 0:
|
||||
if current_step % opt['logger']['print_freq'] == 0 and rank <= 0:
|
||||
logs = model.get_current_log(current_step)
|
||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||
for v in model.get_current_learning_rate():
|
||||
|
@ -212,16 +212,15 @@ def main():
|
|||
message += ')] '
|
||||
for k, v in logs.items():
|
||||
if 'histogram' in k:
|
||||
if rank <= 0:
|
||||
tb_logger.add_histogram(k, v, current_step)
|
||||
tb_logger.add_histogram(k, v, current_step)
|
||||
elif isinstance(v, dict):
|
||||
tb_logger.add_scalars(k, v, current_step)
|
||||
else:
|
||||
message += '{:s}: {:.4e} '.format(k, v)
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
if rank <= 0:
|
||||
tb_logger.add_scalar(k, v, current_step)
|
||||
if rank <= 0:
|
||||
logger.info(message)
|
||||
logger.info(message)
|
||||
|
||||
#### save models and training states
|
||||
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user