diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py
index 5b9857c5..0be355dd 100644
--- a/codes/models/archs/StructuredSwitchedGenerator.py
+++ b/codes/models/archs/StructuredSwitchedGenerator.py
@@ -184,10 +184,15 @@ class SSGr1(nn.Module):
         self.init_temperature = init_temperature
         self.final_temperature_step = 10000

-    def forward(self, x, ref, ref_center):
+    def forward(self, x, ref, ref_center, save_attentions=True):
         # The attention_maps debugger outputs <x>. Save that here.
         self.lr = x.detach().cpu()

+        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
+        # norm should only be getting updates with new data, not recurrent generator sampling.
+        for sw in self.switches:
+            sw.set_update_attention_norm(save_attentions)
+
         x_grad = self.get_g_nopadding(x)
         ref_code = checkpoint(self.reference_embedding, ref, ref_center)
         ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8)
@@ -206,7 +211,8 @@ class SSGr1(nn.Module):
         x_out = checkpoint(self.upsample, x_out)
         x_out = checkpoint(self.final_hr_conv2, x_out)

-        self.attentions = [a1, a3, a4]
+        if save_attentions:
+            self.attentions = [a1, a3, a4]
         self.grad_fea_std = grad_fea_std.detach().cpu()
         self.fea_grad_std = fea_grad_std.detach().cpu()
         return x_grad_out, x_out, x_grad
@@ -265,7 +271,7 @@ class StackedSwitchGenerator(nn.Module):
         self.init_temperature = init_temperature
         self.final_temperature_step = 10000

-    def forward(self, x, ref, ref_center):
+    def forward(self, x, ref, ref_center, save_attentions=True):
         # The attention_maps debugger outputs <x>. Save that here.
         self.lr = x.detach().cpu()

@@ -280,7 +286,8 @@ class StackedSwitchGenerator(nn.Module):
         x_out = checkpoint(self.upsample, x_out)
         x_out = checkpoint(self.final_hr_conv2, x_out)

-        self.attentions = [a1, a3, a3]
+        if save_attentions:
+            self.attentions = [a1, a3, a3]
         return x_out,

     def set_temperature(self, temp):
diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 7b6edfc8..644beeb1 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -104,7 +104,10 @@ class ConfigurableSwitchComputer(nn.Module):
         # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
         # depending on its needs.
         self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
+        self.update_norm = True

+    def set_update_attention_norm(self, set_val):
+        self.update_norm = set_val

     # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element
     # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity.
@@ -148,7 +151,7 @@ class ConfigurableSwitchComputer(nn.Module):
             m = self.multiplexer(*att_in)

         # It is assumed that [xformed] and [m] are collapsed into tensors at this point.
-        outputs, attention = self.switch(xformed, m, True)
+        outputs, attention = self.switch(xformed, m, True, self.update_norm)
         outputs = identity + outputs * self.switch_scale * fixed_scale
         outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale
         if output_attention_weights:
diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index f54a81f6..6e870f92 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -4,6 +4,7 @@ from data.weight_scheduler import get_scheduler_for_opt
 from utils.util import checkpoint
 import torchvision.utils as utils
 #from models.steps.recursive_gen_injectors import ImageFlowInjector
+from models.steps.losses import extract_params_from_state

 # Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
 def create_injector(opt_inject, env):
@@ -43,7 +44,6 @@ class Injector(torch.nn.Module):
     def forward(self, state):
         raise NotImplementedError

-
 # Uses a generator to synthesize an image from [in] and injects the results into [out]
 # Note that results are *not* detached.
 class ImageGeneratorInjector(Injector):
@@ -53,7 +53,7 @@ class ImageGeneratorInjector(Injector):
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         if isinstance(self.input, list):
-            params = [state[i] for i in self.input]
+            params = extract_params_from_state(self.input, state)
             results = gen(*params)
         else:
             results = gen(state[self.input])
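
Note: the update_norm gate exists because the switches keep running statistics
over attention activity (the "attention norm"), and those statistics should
only move when the generator sees fresh data, not when it is re-invoked on its
own output during recurrent sampling. A minimal, self-contained sketch of the
pattern (ToySwitch is a hypothetical stand-in, not the repo's actual switch
implementation):

    import torch
    import torch.nn as nn

    class ToySwitch(nn.Module):
        def __init__(self, channels):
            super().__init__()
            self.update_norm = True
            # Running statistic over attention activity, analogous to the attention norm.
            self.register_buffer('att_norm', torch.ones(channels))

        def set_update_attention_norm(self, set_val):
            self.update_norm = set_val

        def forward(self, x):
            # Per-channel attention derived from spatial means.
            att = torch.softmax(x.mean(dim=(2, 3)), dim=1)
            if self.update_norm:
                # Only fresh data is allowed to move the running statistic.
                self.att_norm = .99 * self.att_norm + .01 * att.mean(dim=0)
            return x * att.unsqueeze(-1).unsqueeze(-1)

    sw = ToySwitch(4)
    out = sw(torch.randn(2, 4, 8, 8))    # fresh data: att_norm is updated
    sw.set_update_attention_norm(False)  # entering recurrent sampling
    out = sw(out)                        # att_norm stays frozen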
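
Note: extract_params_from_state (imported from models.steps.losses, whose body
is not part of this diff) generalizes the old comprehension
`params = [state[i] for i in self.input]`. A plausible sketch, assuming it
simply resolves string entries as state keys and passes any other value
through unchanged:

    def extract_params_from_state(params, state):
        # Strings index into the shared training state; anything else
        # (e.g. literal constants listed in the config) passes through as-is.
        return [state[p] if isinstance(p, str) else p for p in params]

Under that assumption, an injector's input list can mix state keys with fixed
arguments that are handed to the generator verbatim.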