Output histograms with SwitchedResidualGenerator

This also fixes the initialization weight for the configurable generator.
This commit is contained in:
James Betker 2020-06-16 15:54:37 -06:00
parent f8b67f134b
commit 379b96eb55
2 changed files with 29 additions and 15 deletions

View File

@ -107,6 +107,7 @@ class SwitchedResidualGenerator(nn.Module):
self.init_temperature = initial_temp
self.final_temperature_step = final_temperature_step
self.running_sum = [0, 0, 0, 0]
self.running_hist = [[],[],[],[]]
self.running_count = 0
def forward(self, x):
@ -119,16 +120,20 @@ class SwitchedResidualGenerator(nn.Module):
sw4, self.a4 = self.switch4.forward(x, True)
x = x + sw4
a1mean, _ = compute_attention_specificity(self.a1, 2)
a2mean, _ = compute_attention_specificity(self.a2, 2)
a3mean, _ = compute_attention_specificity(self.a3, 2)
a4mean, _ = compute_attention_specificity(self.a4, 2)
a1mean, a1i = compute_attention_specificity(self.a1, 2)
a2mean, a2i = compute_attention_specificity(self.a2, 2)
a3mean, a3i = compute_attention_specificity(self.a3, 2)
a4mean, a4i = compute_attention_specificity(self.a4, 2)
running_sum = [
self.running_sum[0] + a1mean,
self.running_sum[1] + a2mean,
self.running_sum[2] + a3mean,
self.running_sum[3] + a4mean,
]
self.running_hist[0].append(a1i.detach().cpu().flatten())
self.running_hist[1].append(a2i.detach().cpu().flatten())
self.running_hist[2].append(a3i.detach().cpu().flatten())
self.running_hist[3].append(a4i.detach().cpu().flatten())
self.running_count += 1
return (x,)
@ -146,14 +151,16 @@ class SwitchedResidualGenerator(nn.Module):
if step % 250 == 0:
save_attention_to_image(self.a1, 4, step, "a1")
save_attention_to_image(self.a2, 8, step, "a2")
save_attention_to_image(self.a3, 16, step, "a3", 2)
save_attention_to_image(self.a4, 32, step, "a4", 4)
save_attention_to_image(self.a2, 8, step, "a2", 2)
save_attention_to_image(self.a3, 16, step, "a3", 4)
save_attention_to_image(self.a4, 32, step, "a4", 8)
val = {"switch_temperature": temp}
for i in range(len(self.running_sum)):
val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count
val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i])
self.running_sum[i] = 0
self.running_hist[i] = []
self.running_count = 0
return val
@ -166,12 +173,13 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
initialize_weights(switches, 1)
# Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
initialize_weights([s.transforms for s in switches], .05)
initialize_weights([s.transforms for s in switches], .2 / len(switches))
self.switches = nn.ModuleList(switches)
self.transformation_counts = trans_counts
self.init_temperature = initial_temp
self.final_temperature_step = final_temperature_step
self.running_sum = [0 for i in range(len(switches))]
self.running_hist = [[] for i in range(len(switches))]
self.running_count = 0
def forward(self, x):
@ -179,8 +187,9 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
for i, sw in enumerate(self.switches):
x, att = sw.forward(x, True)
self.attentions.append(att)
spec, _ = compute_attention_specificity(att, 2)
spec, hist = compute_attention_specificity(att, 2)
self.running_sum[i] += spec
self.running_hist[i].append(hist.detach().cpu().flatten())
self.running_count += 1
@ -201,5 +210,7 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
for i in range(len(self.running_sum)):
val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count
self.running_sum[i] = 0
val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i])
self.running_hist[i] = []
self.running_count = 0
return val

View File

@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_residualgenerator.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@ -193,11 +193,14 @@ def main():
message += '{:.3e},'.format(v)
message += ')] '
for k, v in logs.items():
message += '{:s}: {:.4e} '.format(k, v)
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
if rank <= 0:
tb_logger.add_scalar(k, v, current_step)
if 'histogram' in k:
tb_logger.add_histogram(k, v, current_step)
else:
message += '{:s}: {:.4e} '.format(k, v)
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
if rank <= 0:
tb_logger.add_scalar(k, v, current_step)
if rank <= 0:
logger.info(message)
#### validation