forked from mrq/DL-Art-School
ChainedGen: Output debugging information on blocks
This commit is contained in:
parent
b54de69153
commit
5753e77d67
|
@ -149,6 +149,8 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
|
||||||
self.grad_extract = ImageGradientNoPadding()
|
self.grad_extract = ImageGradientNoPadding()
|
||||||
self.upsample = FinalUpsampleBlock2x(64)
|
self.upsample = FinalUpsampleBlock2x(64)
|
||||||
self.ref_join_std = 0
|
self.ref_join_std = 0
|
||||||
|
self.block_residual_means = [0 for _ in range(depth)]
|
||||||
|
self.block_residual_stds = [0 for _ in range(depth)]
|
||||||
self.bypass_maps = []
|
self.bypass_maps = []
|
||||||
|
|
||||||
def forward(self, x, recurrent=None):
|
def forward(self, x, recurrent=None):
|
||||||
|
@ -172,6 +174,8 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
|
||||||
residual, bypass_map = checkpoint(self.bypasses[i], residual, context)
|
residual, bypass_map = checkpoint(self.bypasses[i], residual, context)
|
||||||
fea = fea + residual
|
fea = fea + residual
|
||||||
self.bypass_maps.append(bypass_map.detach())
|
self.bypass_maps.append(bypass_map.detach())
|
||||||
|
self.block_residual_means[i] = residual.mean().item()
|
||||||
|
self.block_residual_stds[i] = residual.std().item()
|
||||||
if i < 3:
|
if i < 3:
|
||||||
structure_br = checkpoint(self.structure_joins[i], grad, fea)
|
structure_br = checkpoint(self.structure_joins[i], grad, fea)
|
||||||
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
||||||
|
@ -184,4 +188,9 @@ class StructuredChainedEmbeddingGenWithBypass(nn.Module):
|
||||||
|
|
||||||
def get_debug_values(self, step, net_name):
|
def get_debug_values(self, step, net_name):
|
||||||
biases = [b.bias.item() for b in self.bypasses]
|
biases = [b.bias.item() for b in self.bypasses]
|
||||||
return { 'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases) }
|
blk_stds, blk_means = {}, {}
|
||||||
|
for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)):
|
||||||
|
blk_stds['block_%i' % (i+1,)] = s
|
||||||
|
blk_means['block_%i' % (i+1,)] = m
|
||||||
|
return {'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases),
|
||||||
|
'blocks_std': blk_stds, 'blocks_mean': blk_means}
|
||||||
|
|
|
@ -204,7 +204,7 @@ def main():
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
#### log
|
#### log
|
||||||
if current_step % opt['logger']['print_freq'] == 0:
|
if current_step % opt['logger']['print_freq'] == 0 and rank <= 0:
|
||||||
logs = model.get_current_log(current_step)
|
logs = model.get_current_log(current_step)
|
||||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||||
for v in model.get_current_learning_rate():
|
for v in model.get_current_learning_rate():
|
||||||
|
@ -212,15 +212,14 @@ def main():
|
||||||
message += ')] '
|
message += ')] '
|
||||||
for k, v in logs.items():
|
for k, v in logs.items():
|
||||||
if 'histogram' in k:
|
if 'histogram' in k:
|
||||||
if rank <= 0:
|
|
||||||
tb_logger.add_histogram(k, v, current_step)
|
tb_logger.add_histogram(k, v, current_step)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
tb_logger.add_scalars(k, v, current_step)
|
||||||
else:
|
else:
|
||||||
message += '{:s}: {:.4e} '.format(k, v)
|
message += '{:s}: {:.4e} '.format(k, v)
|
||||||
# tensorboard logger
|
# tensorboard logger
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
if rank <= 0:
|
|
||||||
tb_logger.add_scalar(k, v, current_step)
|
tb_logger.add_scalar(k, v, current_step)
|
||||||
if rank <= 0:
|
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
|
|
||||||
#### save models and training states
|
#### save models and training states
|
||||||
|
|
|
@ -204,7 +204,7 @@ def main():
|
||||||
_t = time()
|
_t = time()
|
||||||
|
|
||||||
#### log
|
#### log
|
||||||
if current_step % opt['logger']['print_freq'] == 0:
|
if current_step % opt['logger']['print_freq'] == 0 and rank <= 0:
|
||||||
logs = model.get_current_log(current_step)
|
logs = model.get_current_log(current_step)
|
||||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(epoch, current_step)
|
||||||
for v in model.get_current_learning_rate():
|
for v in model.get_current_learning_rate():
|
||||||
|
@ -212,15 +212,14 @@ def main():
|
||||||
message += ')] '
|
message += ')] '
|
||||||
for k, v in logs.items():
|
for k, v in logs.items():
|
||||||
if 'histogram' in k:
|
if 'histogram' in k:
|
||||||
if rank <= 0:
|
|
||||||
tb_logger.add_histogram(k, v, current_step)
|
tb_logger.add_histogram(k, v, current_step)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
tb_logger.add_scalars(k, v, current_step)
|
||||||
else:
|
else:
|
||||||
message += '{:s}: {:.4e} '.format(k, v)
|
message += '{:s}: {:.4e} '.format(k, v)
|
||||||
# tensorboard logger
|
# tensorboard logger
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||||
if rank <= 0:
|
|
||||||
tb_logger.add_scalar(k, v, current_step)
|
tb_logger.add_scalar(k, v, current_step)
|
||||||
if rank <= 0:
|
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
|
|
||||||
#### save models and training states
|
#### save models and training states
|
||||||
|
|
Loading…
Reference in New Issue
Block a user