diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index fb5296a4..09b5b60d 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -107,6 +107,7 @@ class SwitchedResidualGenerator(nn.Module): self.init_temperature = initial_temp self.final_temperature_step = final_temperature_step self.running_sum = [0, 0, 0, 0] + self.running_hist = [[],[],[],[]] self.running_count = 0 def forward(self, x): @@ -119,16 +120,20 @@ class SwitchedResidualGenerator(nn.Module): sw4, self.a4 = self.switch4.forward(x, True) x = x + sw4 - a1mean, _ = compute_attention_specificity(self.a1, 2) - a2mean, _ = compute_attention_specificity(self.a2, 2) - a3mean, _ = compute_attention_specificity(self.a3, 2) - a4mean, _ = compute_attention_specificity(self.a4, 2) + a1mean, a1i = compute_attention_specificity(self.a1, 2) + a2mean, a2i = compute_attention_specificity(self.a2, 2) + a3mean, a3i = compute_attention_specificity(self.a3, 2) + a4mean, a4i = compute_attention_specificity(self.a4, 2) running_sum = [ self.running_sum[0] + a1mean, self.running_sum[1] + a2mean, self.running_sum[2] + a3mean, self.running_sum[3] + a4mean, ] + self.running_hist[0].append(a1i.detach().cpu().flatten()) + self.running_hist[1].append(a2i.detach().cpu().flatten()) + self.running_hist[2].append(a3i.detach().cpu().flatten()) + self.running_hist[3].append(a4i.detach().cpu().flatten()) self.running_count += 1 return (x,) @@ -146,14 +151,16 @@ class SwitchedResidualGenerator(nn.Module): if step % 250 == 0: save_attention_to_image(self.a1, 4, step, "a1") - save_attention_to_image(self.a2, 8, step, "a2") - save_attention_to_image(self.a3, 16, step, "a3", 2) - save_attention_to_image(self.a4, 32, step, "a4", 4) + save_attention_to_image(self.a2, 8, step, "a2", 2) + save_attention_to_image(self.a3, 16, step, "a3", 4) + save_attention_to_image(self.a4, 32, step, "a4", 8) val = {"switch_temperature": temp} for i in range(len(self.running_sum)): val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count + val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i]) self.running_sum[i] = 0 + self.running_hist[i] = [] self.running_count = 0 return val @@ -166,12 +173,13 @@ class ConfigurableSwitchedResidualGenerator(nn.Module): switches.append(SwitchComputer(3, filters, functools.partial(ResidualBranch, 3, 3, kernel_size=kernel, depth=layers), trans_count, sw_reduce, sw_proc, initial_temp)) initialize_weights(switches, 1) # Initialize the transforms with a lesser weight, since they are repeatedly added on to the resultant image. - initialize_weights([s.transforms for s in switches], .05) + initialize_weights([s.transforms for s in switches], .2 / len(switches)) self.switches = nn.ModuleList(switches) self.transformation_counts = trans_counts self.init_temperature = initial_temp self.final_temperature_step = final_temperature_step self.running_sum = [0 for i in range(len(switches))] + self.running_hist = [[] for i in range(len(switches))] self.running_count = 0 def forward(self, x): @@ -179,8 +187,9 @@ class ConfigurableSwitchedResidualGenerator(nn.Module): for i, sw in enumerate(self.switches): x, att = sw.forward(x, True) self.attentions.append(att) - spec, _ = compute_attention_specificity(att, 2) + spec, hist = compute_attention_specificity(att, 2) self.running_sum[i] += spec + self.running_hist[i].append(hist.detach().cpu().flatten()) self.running_count += 1 @@ -201,5 +210,7 @@ class ConfigurableSwitchedResidualGenerator(nn.Module): for i in range(len(self.running_sum)): val["switch_%i_specificity" % (i,)] = self.running_sum[i] / self.running_count self.running_sum[i] = 0 + val["switch_%i_histogram" % (i,)] = torch.cat(self.running_hist[i]) + self.running_hist[i] = [] self.running_count = 0 return val \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index d0ee97ff..b2763c25 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_lowdim_rrdb_no_sr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_residualgenerator.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -193,11 +193,14 @@ def main(): message += '{:.3e},'.format(v) message += ')] ' for k, v in logs.items(): - message += '{:s}: {:.4e} '.format(k, v) - # tensorboard logger - if opt['use_tb_logger'] and 'debug' not in opt['name']: - if rank <= 0: - tb_logger.add_scalar(k, v, current_step) + if 'histogram' in k: + tb_logger.add_histogram(k, v, current_step) + else: + message += '{:s}: {:.4e} '.format(k, v) + # tensorboard logger + if opt['use_tb_logger'] and 'debug' not in opt['name']: + if rank <= 0: + tb_logger.add_scalar(k, v, current_step) if rank <= 0: logger.info(message) #### validation