forked from mrq/DL-Art-School
Update how attention_maps are created
parent c139f5cd17
commit f33ed578a2
@@ -7,7 +7,8 @@ from collections import OrderedDict
 from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock
 from models.archs.RRDBNet_arch import ResidualDenseBlock_5C, RRDB
 from models.archs.spinenet_arch import SpineNet
-from switched_conv_util import save_attention_to_image
+from switched_conv_util import save_attention_to_image_rgb
+import os


 class MultiConvBlock(nn.Module):
@@ -228,7 +229,9 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
             temp = 1 / temp
         self.set_temperature(temp)
         if step % 50 == 0:
-            [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
+            output_path = os.path.join(experiments_path, "attention_maps", "a%i")
+            prefix = "attention_map_%i_%%i.png" % (step,)
+            [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]

     def get_debug_values(self, step):
         temp = self.switches[0].switch.temperature
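
The doubled percent sign in the new prefix is easy to misread: the first `%` formatting pass fills in the step number, and the `%%i` escape survives it as a literal `%i` that save_attention_to_image_rgb can fill in later, per image. A quick illustration of how the two strings resolve (experiments_path and step here are hypothetical example values, not from the commit):

import os

# Hypothetical values, for illustration only.
experiments_path = "../experiments/demo_run"
step = 150

# Mirrors the new code: one directory per switch, selected with output_path % (i,).
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
# %%i survives this formatting pass as a literal %i for the saver to fill later.
prefix = "attention_map_%i_%%i.png" % (step,)

print(output_path % (0,))  # ../experiments/demo_run/attention_maps/a0
print(prefix)              # attention_map_150_%i.png

Note that the per-switch directories are now zero-indexed (a0, a1, ...), whereas the old naming, "a%i" % (i+1,), started at a1.
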
@@ -335,7 +338,9 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module):
             temp = 1 / temp
         self.set_temperature(temp)
         if step % 50 == 0:
-            [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
+            output_path = os.path.join(experiments_path, "attention_maps", "a%i")
+            prefix = "attention_map_%i_%%i.png" % (step,)
+            [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]

     def get_debug_values(self, step):
         temp = self.switches[0].switch.temperature
@@ -426,7 +431,9 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
             temp = 1 / temp
         self.set_temperature(temp)
         if step % 50 == 0:
-            save_attention_to_image(experiments_path, self.attentions[0], self.transformation_counts, step, "a%i" % (1,), l_mult=10)
+            output_path = os.path.join(experiments_path, "attention_maps", "a%i")
+            prefix = "attention_map_%i_%%i.png" % (step,)
+            [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]

     def get_debug_values(self, step):
         temp = self.switch.switch.temperature
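
The commit does not touch switched_conv_util itself, so the body of save_attention_to_image_rgb is not shown here. Judging only from the call sites above, it takes (output_folder, attention_tensor, num_transforms, file_prefix, step), and the old l_mult=10 argument is dropped. The following is a minimal sketch of what such a helper could look like; it is purely an assumption for illustration, not the actual switched_conv_util implementation:

import os
import torch
import torchvision.utils as vutils

def save_attention_to_image_rgb(output_folder, attention_out, attention_size, file_prefix, step):
    # Sketch only: the real helper in switched_conv_util may differ.
    # attention_out is assumed to be (batch, H, W, num_transforms) softmax weights.
    os.makedirs(output_folder, exist_ok=True)
    weight, index = torch.max(attention_out, dim=-1)   # winning transform per pixel
    hue = index.float() / max(attention_size - 1, 1)   # transform id mapped to [0, 1]
    # Crude RGB encoding standing in for a proper colormap: mixes "which
    # transform won" (hue) with "how confidently it won" (weight).
    img = torch.stack([hue, weight, 1.0 - hue], dim=1)  # (batch, 3, H, W)
    for b in range(img.shape[0]):
        # file_prefix already contains the step; the remaining %i is filled per image.
        vutils.save_image(img[b], os.path.join(output_folder, file_prefix % (b,)))
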