Denoise attention maps

This commit is contained in:
James Betker 2020-08-10 14:59:58 -06:00
parent 59aba1daa7
commit f0e2816239

View File

@ -531,7 +531,7 @@ class SwitchedSpsr(nn.Module):
temp = max(1, 1 + self.init_temperature * temp = max(1, 1 + self.init_temperature *
(self.final_temperature_step - step) / self.final_temperature_step) (self.final_temperature_step - step) / self.final_temperature_step)
self.set_temperature(temp) self.set_temperature(temp)
if step % 10 == 0: if step % 200 == 0:
output_path = os.path.join(experiments_path, "attention_maps", "a%i") output_path = os.path.join(experiments_path, "attention_maps", "a%i")
prefix = "attention_map_%i_%%i.png" % (step,) prefix = "attention_map_%i_%%i.png" % (step,)
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
@ -548,8 +548,6 @@ class SwitchedSpsr(nn.Module):
return val return val
class SwitchedSpsrLr(nn.Module): class SwitchedSpsrLr(nn.Module):
def __init__(self, in_nc, out_nc, nf, upscale=4): def __init__(self, in_nc, out_nc, nf, upscale=4):
super(SwitchedSpsrLr, self).__init__() super(SwitchedSpsrLr, self).__init__()
@ -657,7 +655,7 @@ class SwitchedSpsrLr(nn.Module):
temp = max(1, 1 + self.init_temperature * temp = max(1, 1 + self.init_temperature *
(self.final_temperature_step - step) / self.final_temperature_step) (self.final_temperature_step - step) / self.final_temperature_step)
self.set_temperature(temp) self.set_temperature(temp)
if step % 10 == 0: if step % 200 == 0:
output_path = os.path.join(experiments_path, "attention_maps", "a%i") output_path = os.path.join(experiments_path, "attention_maps", "a%i")
prefix = "attention_map_%i_%%i.png" % (step,) prefix = "attention_map_%i_%%i.png" % (step,)
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]