From 5753e77d67b40db0c3009d2117dc511250e2d32d Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 21 Oct 2020 16:36:23 -0600 Subject: [PATCH] ChainedGen: Output debugging information on blocks --- codes/models/archs/ChainedEmbeddingGen.py | 11 ++++++++++- codes/train.py | 11 +++++------ codes/train2.py | 11 +++++------ 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index fd1e76fc..34e6b5b3 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -149,6 +149,8 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module): self.grad_extract = ImageGradientNoPadding() self.upsample = FinalUpsampleBlock2x(64) self.ref_join_std = 0 + self.block_residual_means = [0 for _ in range(depth)] + self.block_residual_stds = [0 for _ in range(depth)] self.bypass_maps = [] def forward(self, x, recurrent=None): @@ -172,6 +174,8 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module): residual, bypass_map = checkpoint(self.bypasses[i], residual, context) fea = fea + residual self.bypass_maps.append(bypass_map.detach()) + self.block_residual_means[i] = residual.mean().item() + self.block_residual_stds[i] = residual.std().item() if i < 3: structure_br = checkpoint(self.structure_joins[i], grad, fea) grad = grad + checkpoint(self.structure_blocks[i], structure_br) @@ -184,4 +188,9 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module): def get_debug_values(self, step, net_name): biases = [b.bias.item() for b in self.bypasses] - return { 'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases) } + blk_stds, blk_means = {}, {} + for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)): + blk_stds['block_%i' % (i+1,)] = s + blk_means['block_%i' % (i+1,)] = m + return {'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases), + 'blocks_std': blk_stds, 'blocks_mean': blk_means} diff --git a/codes/train.py b/codes/train.py index 122ce55f..931a4c95 100644 --- a/codes/train.py +++ b/codes/train.py @@ -204,7 +204,7 @@ def main(): _t = time() #### log - if current_step % opt['logger']['print_freq'] == 0: + if current_step % opt['logger']['print_freq'] == 0 and rank <= 0: logs = model.get_current_log(current_step) message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) for v in model.get_current_learning_rate(): @@ -212,16 +212,15 @@ def main(): message += ')] ' for k, v in logs.items(): if 'histogram' in k: - if rank <= 0: - tb_logger.add_histogram(k, v, current_step) + tb_logger.add_histogram(k, v, current_step) + elif isinstance(v, dict): + tb_logger.add_scalars(k, v, current_step) else: message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: - if rank <= 0: tb_logger.add_scalar(k, v, current_step) - if rank <= 0: - logger.info(message) + logger.info(message) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: diff --git a/codes/train2.py b/codes/train2.py index 71517fa4..e6348a72 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -204,7 +204,7 @@ def main(): _t = time() #### log - if current_step % opt['logger']['print_freq'] == 0: + if current_step % opt['logger']['print_freq'] == 0 and rank <= 0: logs = model.get_current_log(current_step) message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step) for v in model.get_current_learning_rate(): @@ -212,16 +212,15 @@ def main(): message += ')] ' for k, v in logs.items(): if 'histogram' in k: - if rank <= 0: - tb_logger.add_histogram(k, v, current_step) + tb_logger.add_histogram(k, v, current_step) + elif isinstance(v, dict): + tb_logger.add_scalars(k, v, current_step) else: message += '{:s}: {:.4e} '.format(k, v) # tensorboard logger if opt['use_tb_logger'] and 'debug' not in opt['name']: - if rank <= 0: tb_logger.add_scalar(k, v, current_step) - if rank <= 0: - logger.info(message) + logger.info(message) #### save models and training states if current_step % opt['logger']['save_checkpoint_freq'] == 0: