Don't compute attention statistics on multiple generator invocations of the same data
parent e760658fdb
commit 51044929af
@@ -184,10 +184,15 @@ class SSGr1(nn.Module):
         self.init_temperature = init_temperature
         self.final_temperature_step = 10000
 
-    def forward(self, x, ref, ref_center):
+    def forward(self, x, ref, ref_center, save_attentions=True):
         # The attention_maps debugger outputs <x>. Save that here.
         self.lr = x.detach().cpu()
 
+        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
+        # norm should only be getting updates with new data, not recurrent generator sampling.
+        for sw in self.switches:
+            sw.set_update_attention_norm(save_attentions)
+
         x_grad = self.get_g_nopadding(x)
         ref_code = checkpoint(self.reference_embedding, ref, ref_center)
         ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
@@ -206,7 +211,8 @@ class SSGr1(nn.Module):
         x_out = checkpoint(self.upsample, x_out)
         x_out = checkpoint(self.final_hr_conv2, x_out)
 
-        self.attentions = [a1, a3, a4]
+        if save_attentions:
+            self.attentions = [a1, a3, a4]
         self.grad_fea_std = grad_fea_std.detach().cpu()
         self.fea_grad_std = fea_grad_std.detach().cpu()
         return x_grad_out, x_out, x_grad
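
Taken together, these two hunks make SSGr1 skip both the attention-norm updates and the cached attention maps whenever save_attentions is False. A minimal sketch of the calling pattern this enables, using a toy stand-in module rather than SSGr1 itself (the toy classes, shapes, and placeholder math below are illustrative assumptions, not code from this repository):

import torch
import torch.nn as nn

class ToySwitch(nn.Module):
    """Stand-in for a switch that can be told whether to update its attention norm."""
    def __init__(self):
        super().__init__()
        self.update_norm = True

    def set_update_attention_norm(self, set_val):
        self.update_norm = set_val

class ToyGenerator(nn.Module):
    """Mimics the save_attentions gating added to SSGr1's forward()."""
    def __init__(self):
        super().__init__()
        self.switches = nn.ModuleList([ToySwitch(), ToySwitch()])
        self.attentions = None

    def forward(self, x, save_attentions=True):
        # Attention-norm statistics should only see genuinely new data.
        for sw in self.switches:
            sw.set_update_attention_norm(save_attentions)
        out = x * 2  # placeholder computation
        if save_attentions:
            self.attentions = [out.detach().cpu()]
        return out

gen = ToyGenerator()
x = torch.randn(1, 3, 16, 16)
first = gen(x)                              # fresh data: attention statistics are updated
second = gen(first, save_attentions=False)  # re-invocation on derived data: statistics untouched

StackedSwitchGenerator receives the identical treatment in the next two hunks.
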
@@ -265,7 +271,7 @@ class StackedSwitchGenerator(nn.Module):
         self.init_temperature = init_temperature
         self.final_temperature_step = 10000
 
-    def forward(self, x, ref, ref_center):
+    def forward(self, x, ref, ref_center, save_attentions=True):
         # The attention_maps debugger outputs <x>. Save that here.
         self.lr = x.detach().cpu()
 
@@ -280,7 +286,8 @@ class StackedSwitchGenerator(nn.Module):
         x_out = checkpoint(self.upsample, x_out)
         x_out = checkpoint(self.final_hr_conv2, x_out)
 
-        self.attentions = [a1, a3, a3]
+        if save_attentions:
+            self.attentions = [a1, a3, a3]
         return x_out,
 
     def set_temperature(self, temp):
@@ -104,7 +104,10 @@ class ConfigurableSwitchComputer(nn.Module):
         # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
+        self.update_norm = True
 
+    def set_update_attention_norm(self, set_val):
+        self.update_norm = set_val
 
     # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
     # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
@@ -148,7 +151,7 @@ class ConfigurableSwitchComputer(nn.Module):
         m = self.multiplexer(*att_in)
 
         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
-        outputs, attention = self.switch(xformed, m, True)
+        outputs, attention = self.switch(xformed, m, True, self.update_norm)
         outputs = identity + outputs * self.switch_scale * fixed_scale
         outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
         if output_attention_weights:
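
At the switch level, the extra self.update_norm argument forwarded to self.switch(...) presumably tells the underlying switch whether to fold the current batch into its attention-norm statistics. A hedged sketch of what honoring that flag could look like; the real switch comes from the switched-conv dependency and its internals may differ, so the shapes, softmax, and momentum update here are assumptions:

import torch
import torch.nn as nn

class SketchSwitch(nn.Module):
    """Illustrative switch that tracks a running norm of its attention maps."""
    def __init__(self, momentum=0.01):
        super().__init__()
        self.momentum = momentum
        # Running statistic kept out of the optimizer via register_buffer.
        self.register_buffer("attention_norm", torch.ones(1))

    def forward(self, xformed, m, output_attention_weights=True, update_norm=True):
        attention = torch.softmax(m, dim=1)
        if update_norm and self.training:
            # Only fold new data into the running statistic when asked to.
            with torch.no_grad():
                batch_norm = attention.norm(dim=(-2, -1)).mean()
                self.attention_norm.mul_(1 - self.momentum).add_(self.momentum * batch_norm)
        # xformed is assumed to be per-transform outputs stacked on dim 1: [B, N, C, H, W].
        outputs = (xformed * attention.unsqueeze(2)).sum(dim=1)
        if output_attention_weights:
            return outputs, attention
        return outputs

sw = SketchSwitch()
xformed = torch.randn(2, 4, 8, 16, 16)   # [batch, n_transforms, channels, H, W]
m = torch.randn(2, 4, 16, 16)            # multiplexer logits, one map per transform
out, att = sw(xformed, m, True, update_norm=False)  # statistics left untouched
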
@@ -4,6 +4,7 @@ from data.weight_scheduler import get_scheduler_for_opt
 from utils.util import checkpoint
 import torchvision.utils as utils
 #from models.steps.recursive_gen_injectors import ImageFlowInjector
+from models.steps.losses import extract_params_from_state
 
 # Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
 def create_injector(opt_inject, env):
@@ -43,7 +44,6 @@ class Injector(torch.nn.Module):
     def forward(self, state):
         raise NotImplementedError
 
-
 # Uses a generator to synthesize an image from [in] and injects the results into [out]
 # Note that results are *not* detached.
 class ImageGeneratorInjector(Injector):
@@ -53,7 +53,7 @@ class ImageGeneratorInjector(Injector):
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         if isinstance(self.input, list):
-            params = [state[i] for i in self.input]
+            params = extract_params_from_state(self.input, state)
             results = gen(*params)
         else:
             results = gen(state[self.input])
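
The injector now delegates list-input handling to extract_params_from_state. Judging from the call site, it resolves each requested key against the shared state dict, much like the list comprehension it replaces; a minimal sketch under that assumption (the real helper in models.steps.losses may also handle non-list inputs or other cases):

def extract_params_from_state(keys, state):
    # Minimal sketch: resolve a key, or a list of keys, against the step's state dict.
    if isinstance(keys, (list, tuple)):
        return [state[k] for k in keys]
    return [state[keys]]
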