Output histograms with SwitchedResidualGenerator
This also fixes the initialization weight for the configurable generator.
parent f8b67f134b
commit 379b96eb55

@@ -107,6 +107,7 @@ class SwitchedResidualGenerator(nn.Module):
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
         self.running_sum = [0, 0, 0, 0]
+        self.running_hist = [[],[],[],[]]
         self.running_count = 0
 
     def forward(self, x):
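The new running_hist buffer mirrors the existing running_sum: one slot per switch (this fixed-architecture generator has four), appended to on every forward pass and drained by the stat-reporting method shown in the later hunks. A self-contained sketch of that accumulate-and-drain pattern, using names from this diff but otherwise hypothetical:

import torch

class SwitchStats:
    # Hypothetical condensation of the buffers this commit adds; not repo code.
    def __init__(self, n_switches=4):
        self.running_sum = [0] * n_switches
        self.running_hist = [[] for _ in range(n_switches)]
        self.running_count = 0

    def accumulate(self, means, index_maps):
        # Called once per forward pass with one (mean, indices) pair per switch.
        for i, (m, idx) in enumerate(zip(means, index_maps)):
            self.running_sum[i] += m
            self.running_hist[i].append(idx.detach().cpu().flatten())
        self.running_count += 1

    def drain(self):
        val = {}
        for i in range(len(self.running_sum)):
            val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count
            val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i])
            self.running_sum[i] = 0
            self.running_hist[i] = []
        self.running_count = 0
        return val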
@@ -119,16 +120,20 @@ class SwitchedResidualGenerator(nn.Module):
         sw4, self.a4 = self.switch4.forward(x, True)
         x = x + sw4
 
-        a1mean, _ = compute_attention_specificity(self.a1, 2)
-        a2mean, _ = compute_attention_specificity(self.a2, 2)
-        a3mean, _ = compute_attention_specificity(self.a3, 2)
-        a4mean, _ = compute_attention_specificity(self.a4, 2)
+        a1mean, a1i = compute_attention_specificity(self.a1, 2)
+        a2mean, a2i = compute_attention_specificity(self.a2, 2)
+        a3mean, a3i = compute_attention_specificity(self.a3, 2)
+        a4mean, a4i = compute_attention_specificity(self.a4, 2)
         running_sum = [
             self.running_sum[0] + a1mean,
             self.running_sum[1] + a2mean,
             self.running_sum[2] + a3mean,
             self.running_sum[3] + a4mean,
         ]
+        self.running_hist[0].append(a1i.detach().cpu().flatten())
+        self.running_hist[1].append(a2i.detach().cpu().flatten())
+        self.running_hist[2].append(a3i.detach().cpu().flatten())
+        self.running_hist[3].append(a4i.detach().cpu().flatten())
         self.running_count += 1
 
         return (x,)
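With the `_` placeholders replaced by a1i..a4i, forward() now consumes the second return value of compute_attention_specificity(). That function's body is not part of this diff; a plausible sketch consistent with how its outputs are used here (a scalar mean plus an index tensor that later gets flattened and concatenated), assuming the last dimension of the attention tensor ranges over the transform branches and the literal 2 is a top-k count:

import torch

def compute_attention_specificity(att, topk=2):
    # att: softmax attention with the last dim over transforms. Sketch only;
    # the real function lives elsewhere in the repo and may differ.
    vals, indices = torch.topk(att, topk, dim=-1)
    # "Specificity": average attention mass captured by the top-k selections.
    mean_specificity = vals.sum(dim=-1).mean().item()
    return mean_specificity, indices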
@@ -146,14 +151,16 @@ class SwitchedResidualGenerator(nn.Module):
 
         if step % 250 == 0:
             save_attention_to_image(self.a1, 4, step, "a1")
-            save_attention_to_image(self.a2, 8, step, "a2")
-            save_attention_to_image(self.a3, 16, step, "a3", 2)
-            save_attention_to_image(self.a4, 32, step, "a4", 4)
+            save_attention_to_image(self.a2, 8, step, "a2", 2)
+            save_attention_to_image(self.a3, 16, step, "a3", 4)
+            save_attention_to_image(self.a4, 32, step, "a4", 8)
 
         val = {"switch_temperature": temp}
         for i in range(len(self.running_sum)):
             val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count
+            val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i])
             self.running_sum[i] = 0
+            self.running_hist[i] = []
         self.running_count = 0
         return val
 
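On each reporting interval the per-step index tensors are concatenated into one 1-D sample vector per switch, then the buffer is cleared; torch.cat is what turns a list of flattened per-batch selections into histogram-ready data:

import torch

hist = [torch.tensor([0, 1, 1]), torch.tensor([2, 0])]  # toy per-step selections
print(torch.cat(hist))  # tensor([0, 1, 1, 2, 0]) -- one entry per selection event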
@@ -166,12 +173,13 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
             switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
         initialize_weights(switches, 1)
         # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
-        initialize_weights([s.transforms for s in switches], .05)
+        initialize_weights([s.transforms for s in switches], .2 / len(switches))
         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
         self.running_sum = [0 for i in range(len(switches))]
+        self.running_hist = [[] for i in range(len(switches))]
         self.running_count = 0
 
     def forward(self, x):
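This hunk carries the fix from the commit message: the transform init scale changes from a fixed .05 to .2 / len(switches). Because every switch adds its transform output onto the image, the total initial perturbation scales with the switch count; dividing by len(switches) bounds the summed contribution at roughly .2 regardless of how many switches are configured. A toy check of that intuition (assumed shapes, not the repo's code):

import torch

x = torch.randn(1, 3, 32, 32)
for n in (2, 4, 8):
    out = x + sum(torch.randn_like(x) * (.2 / n) for _ in range(n))
    # The residual perturbation stays small and even shrinks slightly as n grows.
    print(n, (out - x).std().item())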
@@ -179,8 +187,9 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         for i, sw in enumerate(self.switches):
             x, att = sw.forward(x, True)
             self.attentions.append(att)
-            spec, _ = compute_attention_specificity(att, 2)
+            spec, hist = compute_attention_specificity(att, 2)
             self.running_sum[i] += spec
+            self.running_hist[i].append(hist.detach().cpu().flatten())
 
         self.running_count += 1
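The .detach().cpu().flatten() chain is what makes these buffers safe to hold across hundreds of steps: detach() severs the autograd graph (otherwise each stored tensor would keep its whole backward graph alive), .cpu() moves the copy off the GPU, and .flatten() makes the eventual torch.cat shape-agnostic. Demonstrated standalone:

import torch

buffer = []
for step in range(3):
    att = torch.softmax(torch.randn(4, 8, 8, requires_grad=True), dim=-1)
    # detach: drop grad graph; cpu: release GPU memory; flatten: 1-D for concat
    buffer.append(att.detach().cpu().flatten())
print(torch.cat(buffer).shape)  # torch.Size([768])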
@@ -201,5 +210,7 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         for i in range(len(self.running_sum)):
             val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count
             self.running_sum[i] = 0
+            val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i])
+            self.running_hist[i] = []
         self.running_count = 0
         return val
 
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_residualgenerator.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
 
@@ -193,11 +193,14 @@ def main():
                     message += '{:.3e},'.format(v)
                 message += ')] '
                 for k, v in logs.items():
-                    message += '{:s}: {:.4e} '.format(k, v)
-                    # tensorboard logger
-                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                        if rank <= 0:
-                            tb_logger.add_scalar(k, v, current_step)
+                    if 'histogram' in k:
+                        tb_logger.add_histogram(k, v, current_step)
+                    else:
+                        message += '{:s}: {:.4e} '.format(k, v)
+                        # tensorboard logger
+                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                            if rank <= 0:
+                                tb_logger.add_scalar(k, v, current_step)
                 if rank <= 0:
                     logger.info(message)
                 #### validation
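tb_logger is presumably a TensorBoard SummaryWriter (torch.utils.tensorboard or tensorboardX); both provide add_histogram(tag, values, global_step), which accepts exactly the 1-D tensors the generators now emit under *_histogram keys. Note that, as written, the histogram branch runs outside the use_tb_logger and rank <= 0 guards that still wrap add_scalar. Minimal standalone usage, with a hypothetical log directory:

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="tb_demo")  # hypothetical log dir
samples = torch.cat([torch.randint(0, 8, (256,)) for _ in range(4)])
writer.add_histogram("switch_0_histogram", samples, global_step=0)
writer.close()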