diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py
index 0be355dd..bd0b9c3d 100644
--- a/codes/models/archs/StructuredSwitchedGenerator.py
+++ b/codes/models/archs/StructuredSwitchedGenerator.py
@@ -352,10 +352,15 @@ class SSGDeep(nn.Module):
         self.init_temperature = init_temperature
         self.final_temperature_step = 10000
 
-    def forward(self, x, ref, ref_center):
+    def forward(self, x, ref, ref_center, save_attentions=True):
         # The attention_maps debugger outputs . Save that here.
         self.lr = x.detach().cpu()
 
+        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
+        # norm should only be getting updates with new data, not recurrent generator sampling.
+        for sw in self.switches:
+            sw.set_update_attention_norm(save_attentions)
+
         x_grad = self.get_g_nopadding(x)
         ref_code = checkpoint(self.reference_embedding, ref, ref_center)
         ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
@@ -376,7 +381,8 @@ class SSGDeep(nn.Module):
         x_out = checkpoint(self.upsample, x_out)
         x_out = checkpoint(self.final_hr_conv2, x_out)
 
-        self.attentions = [a1, a3, a4, a5, a6]
+        if save_attentions:
+            self.attentions = [a1, a3, a4, a5, a6]
         self.grad_fea_std = grad_fea_std.detach().cpu()
         self.fea_grad_std = fea_grad_std.detach().cpu()
         return x_grad_out, x_out, x_grad
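
For reference, a minimal usage sketch of the new flag. Only forward(x, ref, ref_center, save_attentions=...), its (x_grad_out, x_out, x_grad) return tuple, and the attention-norm rationale come from this diff; the helper name two_pass_sample, the bilinear downsampling used to feed the output back in, and the surrounding loop are illustrative assumptions, not code from the repository.

import torch.nn.functional as F

def two_pass_sample(model, x, ref, ref_center):
    # Hypothetical driver loop around an SSGDeep-style generator.
    # First pass runs on fresh data: attention maps are saved and the
    # switches' attention norms are allowed to update.
    _, x_out, _ = model(x, ref, ref_center, save_attentions=True)

    # Second, recurrent pass feeds the generator's own output back in. Per the
    # new comment in forward(), this pass should neither save attentions nor
    # update the attention norm, so save_attentions=False is propagated to
    # every switch via set_update_attention_norm().
    x_recurrent = F.interpolate(x_out, size=x.shape[2:], mode='bilinear', align_corners=False)
    _, x_out_recurrent, _ = model(x_recurrent, ref, ref_center, save_attentions=False)
    return x_out_recurrent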