Support attention deferral in deep ssgr
This commit is contained in:
parent
840927063a
commit
4111942ada
|
@ -352,10 +352,15 @@ class SSGDeep(nn.Module):
|
|||
self.init_temperature = init_temperature
|
||||
self.final_temperature_step = 10000
|
||||
|
||||
def forward(self, x, ref, ref_center):
|
||||
def forward(self, x, ref, ref_center, save_attentions=True):
|
||||
# The attention_maps debugger outputs <x>. Save that here.
|
||||
self.lr = x.detach().cpu()
|
||||
|
||||
# If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
|
||||
# norm should only be getting updates with new data, not recurrent generator sampling.
|
||||
for sw in self.switches:
|
||||
sw.set_update_attention_norm(save_attentions)
|
||||
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
ref_code = checkpoint(self.reference_embedding, ref, ref_center)
|
||||
ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
|
||||
|
@ -376,7 +381,8 @@ class SSGDeep(nn.Module):
|
|||
x_out = checkpoint(self.upsample, x_out)
|
||||
x_out = checkpoint(self.final_hr_conv2, x_out)
|
||||
|
||||
self.attentions = [a1, a3, a4, a5, a6]
|
||||
if save_attentions:
|
||||
self.attentions = [a1, a3, a4, a5, a6]
|
||||
self.grad_fea_std = grad_fea_std.detach().cpu()
|
||||
self.fea_grad_std = fea_grad_std.detach().cpu()
|
||||
return x_grad_out, x_out, x_grad
|
||||
|
|
Loading…
Reference in New Issue
Block a user