Fix attention images
This commit is contained in:
parent
7e777ea34c
commit
fe50d6f9d0
|
@ -557,10 +557,10 @@ class SwitchModelBase(nn.Module):
|
||||||
temp = max(1, 1 + self.init_temperature *
|
temp = max(1, 1 + self.init_temperature *
|
||||||
(self.final_temperature_step - step) / self.final_temperature_step)
|
(self.final_temperature_step - step) / self.final_temperature_step)
|
||||||
self.set_temperature(temp)
|
self.set_temperature(temp)
|
||||||
if step % 200 == 0:
|
if step % 100 == 0:
|
||||||
output_path = os.path.join(experiments_path, "attention_maps")
|
output_path = os.path.join(experiments_path, "attention_maps")
|
||||||
prefix = "amap_%i_a%i_%%i.png"
|
prefix = "amap_%i_a%i_%%i.png"
|
||||||
[save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step,
|
[save_attention_to_image_rgb(output_path, self.attentions[i], self.attentions[i].shape[3], prefix % (step, i), step,
|
||||||
output_mag=False) for i in range(len(self.attentions))]
|
output_mag=False) for i in range(len(self.attentions))]
|
||||||
if self.lr is not None:
|
if self.lr is not None:
|
||||||
torchvision.utils.save_image(self.lr[:, :3], os.path.join(experiments_path, "attention_maps",
|
torchvision.utils.save_image(self.lr[:, :3], os.path.join(experiments_path, "attention_maps",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user