forked from mrq/DL-Art-School

Output histograms with SwitchedResidualGenerator

This also fixes the initialization weight for the configurable generator.

parent f8b67f134b
commit 379b96eb55
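In outline, the change is an accumulate-and-flush pattern: every forward pass appends the per-pixel switch selections to a CPU-side buffer, and the periodic logging step concatenates each buffer into a single flat tensor for TensorBoard's histogram writer. A minimal sketch of that pattern, with illustrative names that are not the repository's API:

    import torch

    class HistogramBuffer:
        """Accumulates flattened index tensors between logging steps."""
        def __init__(self):
            self.chunks = []

        def push(self, indices: torch.Tensor):
            # Detach and move to CPU so the buffer holds neither GPU
            # memory nor autograd graph references between flushes.
            self.chunks.append(indices.detach().cpu().flatten())

        def flush(self) -> torch.Tensor:
            # One flat tensor, ready for add_histogram; reset the buffer.
            out = torch.cat(self.chunks)
            self.chunks = []
            return out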
@@ -107,6 +107,7 @@ class SwitchedResidualGenerator(nn.Module):
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
         self.running_sum = [0, 0, 0, 0]
+        self.running_hist = [[],[],[],[]]
         self.running_count = 0
 
     def forward(self, x):
@@ -119,16 +120,20 @@ class SwitchedResidualGenerator(nn.Module):
         sw4, self.a4 = self.switch4.forward(x, True)
         x = x + sw4
 
-        a1mean, _ = compute_attention_specificity(self.a1, 2)
-        a2mean, _ = compute_attention_specificity(self.a2, 2)
-        a3mean, _ = compute_attention_specificity(self.a3, 2)
-        a4mean, _ = compute_attention_specificity(self.a4, 2)
+        a1mean, a1i = compute_attention_specificity(self.a1, 2)
+        a2mean, a2i = compute_attention_specificity(self.a2, 2)
+        a3mean, a3i = compute_attention_specificity(self.a3, 2)
+        a4mean, a4i = compute_attention_specificity(self.a4, 2)
         running_sum = [
             self.running_sum[0] + a1mean,
             self.running_sum[1] + a2mean,
             self.running_sum[2] + a3mean,
             self.running_sum[3] + a4mean,
         ]
+        self.running_hist[0].append(a1i.detach().cpu().flatten())
+        self.running_hist[1].append(a2i.detach().cpu().flatten())
+        self.running_hist[2].append(a3i.detach().cpu().flatten())
+        self.running_hist[3].append(a4i.detach().cpu().flatten())
         self.running_count += 1
 
         return (x,)
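The second return value of compute_attention_specificity, previously discarded as _, is what feeds these buffers. Judging only from how the diff uses it (a scalar mean plus an index tensor that gets flattened for histogramming), a hedged sketch of the kind of reduction involved might look like this; the real helper lives elsewhere in this repository and may differ:

    import torch

    def attention_specificity_sketch(att: torch.Tensor, dim: int = 2):
        # att holds softmax attention weights along `dim`. The winning
        # weight per location measures how decisively the switch picked
        # one transform; the winning index says which one it picked.
        values, indices = att.max(dim=dim)
        return values.mean().item(), indices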
@@ -146,14 +151,16 @@ class SwitchedResidualGenerator(nn.Module):
 
         if step % 250 == 0:
             save_attention_to_image(self.a1, 4, step, "a1")
-            save_attention_to_image(self.a2, 8, step, "a2")
-            save_attention_to_image(self.a3, 16, step, "a3", 2)
-            save_attention_to_image(self.a4, 32, step, "a4", 4)
+            save_attention_to_image(self.a2, 8, step, "a2", 2)
+            save_attention_to_image(self.a3, 16, step, "a3", 4)
+            save_attention_to_image(self.a4, 32, step, "a4", 8)
 
         val = {"switch_temperature": temp}
         for i in range(len(self.running_sum)):
             val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count
+            val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i])
             self.running_sum[i] = 0
+            self.running_hist[i] = []
         self.running_count = 0
         return val
 
@@ -166,12 +173,13 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
             switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp))
         initialize_weights(switches, 1)
         # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image.
-        initialize_weights([s.transforms for s in switches], .05)
+        initialize_weights([s.transforms for s in switches], .2 / len(switches))
         self.switches = nn.ModuleList(switches)
         self.transformation_counts = trans_counts
         self.init_temperature = initial_temp
         self.final_temperature_step = final_temperature_step
         self.running_sum = [0 for i in range(len(switches))]
+        self.running_hist = [[] for i in range(len(switches))]
         self.running_count = 0
 
     def forward(self, x):
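The .2 / len(switches) scale replaces the hard-coded .05 and reproduces it exactly in the four-switch case (.2 / 4 = .05), while damping each transform further as more residual switches are stacked, so the summed residual starts at roughly the same magnitude regardless of switch count. Assuming initialize_weights follows the common BasicSR idiom of Kaiming init followed by a scale multiplier (an assumption, not a quote of this repository's helper), the effect is:

    import torch.nn as nn
    import torch.nn.init as init

    def initialize_weights_sketch(net_l, scale=1.0):
        # Kaiming-init conv weights, then damp them by `scale` so each
        # branch's initial residual contribution stays small.
        if not isinstance(net_l, list):
            net_l = [net_l]
        for net in net_l:
            for m in net.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                    m.weight.data *= scale
                    if m.bias is not None:
                        m.bias.data.zero_()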
@@ -179,8 +187,9 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         for i, sw in enumerate(self.switches):
             x, att = sw.forward(x, True)
             self.attentions.append(att)
-            spec, _ = compute_attention_specificity(att, 2)
+            spec, hist = compute_attention_specificity(att, 2)
             self.running_sum[i] += spec
+            self.running_hist[i].append(hist.detach().cpu().flatten())
 
         self.running_count += 1
 
@@ -201,5 +210,7 @@ class ConfigurableSwitchedResidualGenerator(nn.Module):
         for i in range(len(self.running_sum)):
             val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count
             self.running_sum[i] = 0
+            val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i])
+            self.running_hist[i] = []
         self.running_count = 0
         return val
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_residualgenerator.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -193,11 +193,14 @@ def main():
                     message += '{:.3e},'.format(v)
                 message += ')] '
                 for k, v in logs.items():
-                    message += '{:s}: {:.4e} '.format(k, v)
-                    # tensorboard logger
-                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                        if rank <= 0:
-                            tb_logger.add_scalar(k, v, current_step)
+                    if 'histogram' in k:
+                        tb_logger.add_histogram(k, v, current_step)
+                    else:
+                        message += '{:s}: {:.4e} '.format(k, v)
+                        # tensorboard logger
+                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
+                            if rank <= 0:
+                                tb_logger.add_scalar(k, v, current_step)
                 if rank <= 0:
                     logger.info(message)
            #### validation
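On the logging side, tb_logger is a TensorBoard summary writer, and add_histogram takes a tag, a flat tensor of values, and a global step (the same signature in both tensorboardX and torch.utils.tensorboard). A minimal usage example with made-up data and a hypothetical log directory:

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir='../tb_logger/example')  # hypothetical dir
    # Five accumulated batches of per-pixel selections among 8 transforms.
    selections = torch.cat([torch.randint(0, 8, (4096,)) for _ in range(5)])
    writer.add_histogram('switch_0_histogram', selections, global_step=250)
    writer.close()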