Update how spsr arches do attention to conform with sgsr
This commit is contained in:
parent 9a50a7966d
commit 58886109d4
@@ -391,9 +391,10 @@ class SwitchedSpsrWithRef2(nn.Module):
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
             if step % 200 == 0:
-                output_path = os.path.join(experiments_path, "attention_maps", "a%i")
-                prefix = "attention_map_%i_%%i.png" % (step,)
-                [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
+                output_path = os.path.join(experiments_path, "attention_maps")
+                prefix = "amap_%i_a%i_%%i.png"
+                [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
+                torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))

     def get_debug_values(self, step, net_name):
         temp = self.switches[0].switch.temperature
@@ -523,9 +524,10 @@ class Spsr4(nn.Module):
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
             if step % 200 == 0:
-                output_path = os.path.join(experiments_path, "attention_maps", "a%i")
-                prefix = "attention_map_%i_%%i.png" % (step,)
-                [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
+                output_path = os.path.join(experiments_path, "attention_maps")
+                prefix = "amap_%i_a%i_%%i.png"
+                [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
+                torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))

     def get_debug_values(self, step, net_name):
         temp = self.switches[0].switch.temperature
@@ -654,9 +656,10 @@ class Spsr5(nn.Module):
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
             if step % 200 == 0:
-                output_path = os.path.join(experiments_path, "attention_maps", "a%i")
-                prefix = "attention_map_%i_%%i.png" % (step,)
-                [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
+                output_path = os.path.join(experiments_path, "attention_maps")
+                prefix = "amap_%i_a%i_%%i.png"
+                [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
+                torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))

     def get_debug_values(self, step, net_name):
         temp = self.switches[0].switch.temperature
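All three arches now dump attention visualizations the same way: a single flat attention_maps directory with the switch index folded into the filename (instead of one "a%i" subfolder per switch), magnitude output disabled, and the low-resolution input saved alongside the maps for visual comparison. Below is a minimal standalone sketch of that naming convention; the tensor shapes and the body of save_attention_to_image_rgb are assumptions for illustration only, not the repo's actual helper.

import os

import torch
import torchvision


def save_attention_to_image_rgb(output_path, attention, n_transforms, prefix, step, output_mag=False):
    # Hypothetical stand-in for the repo's helper: write one greyscale image per
    # transformation channel, filling the remaining %i in the caller's prefix.
    # (step and output_mag are accepted but unused in this stub.)
    for t in range(n_transforms):
        img = attention[..., t].unsqueeze(1)  # (N, 1, H, W)
        torchvision.utils.save_image(img, os.path.join(output_path, prefix % (t,)))


# Assumed shapes: lr is an (N, 3, H, W) input batch, each attention map is
# (N, H, W, n_transforms) with values in [0, 1].
step = 200
experiments_path = "."
transformation_counts = 8
lr = torch.rand(1, 3, 32, 32)
attentions = [torch.rand(1, 32, 32, transformation_counts) for _ in range(3)]

# New convention: one flat directory, switch index carried in the filename.
output_path = os.path.join(experiments_path, "attention_maps")
os.makedirs(output_path, exist_ok=True)
prefix = "amap_%i_a%i_%%i.png"
for i, att in enumerate(attentions):
    # prefix % (step, i) -> e.g. "amap_200_a0_%i.png" for switch 0 at step 200
    save_attention_to_image_rgb(output_path, att, transformation_counts, prefix % (step, i), step, output_mag=False)

# The low-res base image is written next to the maps so they can be inspected together.
torchvision.utils.save_image(lr, os.path.join(output_path, "amap_%i_base_image.png" % (step,)))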