More fixes

James Betker 2020-09-24 17:51:52 -06:00
parent 553917a8d1
commit ea565b7eaf


@@ -164,6 +164,9 @@ class SPSRNet(nn.Module):
 
     def forward(self, x: torch.Tensor):
+        # The attention_maps debugger outputs <x>. Save that here.
+        self.lr = x.detach().cpu()
+
         x_grad = self.get_g_nopadding(x)
         b, f, w, h = x.shape
@@ -332,6 +335,9 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000
 
     def forward(self, x, ref, center_coord):
+        # The attention_maps debugger outputs <x>. Save that here.
+        self.lr = x.detach().cpu()
+
         ref_stds = []
         noise_stds = []
@@ -391,7 +397,7 @@ class SwitchedSpsrWithRef2(nn.Module):
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
-            if step % 200 == 0:
+            if step % 500 == 0:
                 output_path = os.path.join(experiments_path, "attention_maps")
                 prefix = "amap_%i_a%i_%%i.png"
                 [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
@@ -476,6 +482,9 @@ class Spsr4(nn.Module):
         self.final_temperature_step = 10000
 
     def forward(self, x, embedding):
+        # The attention_maps debugger outputs <x>. Save that here.
+        self.lr = x.detach().cpu()
+
         noise_stds = []
 
         x_grad = self.get_g_nopadding(x)
@@ -524,7 +533,7 @@ class Spsr4(nn.Module):
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
-            if step % 200 == 0:
+            if step % 500 == 0:
                 output_path = os.path.join(experiments_path, "attention_maps")
                 prefix = "amap_%i_a%i_%%i.png"
                 [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
@@ -606,8 +615,12 @@ class Spsr5(nn.Module):
         self.attentions = None
         self.init_temperature = init_temperature
         self.final_temperature_step = 10000
+        self.lr = None
 
     def forward(self, x, embedding):
+        # The attention_maps debugger outputs <x>. Save that here.
+        self.lr = x.detach().cpu()
+
         noise_stds = []
 
         x_grad = self.get_g_nopadding(x)
@@ -656,7 +669,7 @@ class Spsr5(nn.Module):
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
-            if step % 200 == 0:
+            if step % 500 == 0:
                 output_path = os.path.join(experiments_path, "attention_maps")
                 prefix = "amap_%i_a%i_%%i.png"
                 [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
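The change is the same in every variant: forward() now caches its input on the module so the attention_maps debugger can read it back after the step, and the periodic attention-map dump fires every 500 steps instead of every 200. Below is a minimal self-contained sketch of that pattern; TinySpsrLike, its conv body, and the bare update_for_step signature are invented for illustration (the real update_for_step also takes an experiments_path and calls save_attention_to_image_rgb), while lr, attentions, init_temperature, and final_temperature_step mirror the fields touched in the diff.

import torch
import torch.nn as nn

class TinySpsrLike(nn.Module):
    def __init__(self, init_temperature=10):
        super().__init__()
        self.body = nn.Conv2d(3, 3, 3, padding=1)
        self.attentions = None
        self.init_temperature = init_temperature
        self.final_temperature_step = 10000
        self.lr = None  # last low-res input, cached for the debugger

    def forward(self, x):
        # The attention_maps debugger outputs <x>. Save that here.
        # detach() drops the autograd graph; cpu() keeps the cache off the GPU.
        self.lr = x.detach().cpu()
        out = self.body(x)
        self.attentions = [torch.softmax(out.flatten(1), dim=1)]  # stand-in
        return out

    def update_for_step(self, step):
        # Linear anneal: 1 + init_temperature at step 0, clamped to 1 at the end.
        return max(1, 1 + self.init_temperature *
                   (self.final_temperature_step - step) / self.final_temperature_step)

model = TinySpsrLike()
model(torch.randn(1, 3, 16, 16))
assert model.lr is not None and not model.lr.requires_grad
print(model.update_for_step(0), model.update_for_step(10000))  # 11.0 1

# Two-stage filename formatting used by the dump: "%%i" survives the first
# substitution as "%i" so the image saver can fill in the per-image index.
assert "amap_%i_a%i_%%i.png" % (500, 0) == "amap_500_a0_%i.png"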