diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index 2f83ca08..bea9035d 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -222,7 +222,7 @@ class SSGr1(nn.Module): if step % 200 == 0: output_path = os.path.join(experiments_path, "attention_maps") prefix = "amap_%i_a%i_%%i.png" - [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] + [save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) @@ -239,3 +239,165 @@ class SSGr1(nn.Module): val["switch_%i_histogram" % (i,)] = hists[i] return val + +class StackedSwitchGenerator(nn.Module): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): + super(StackedSwitchGenerator, self).__init__() + n_upscale = int(math.log(upscale, 2)) + self.nf = nf + + # processing the input embedding + self.reference_embedding = ReferenceImageBranch(nf) + + # Feature branch + self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + self.sw1 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) + self.sw2 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) + self.sw3 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) + self.switches = [self.sw1.switch, self.sw2.switch, self.sw3.switch] + + self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) + self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) + self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) + self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False) + self.attentions = None + self.lr = None + self.init_temperature = init_temperature + self.final_temperature_step = 10000 + + def forward(self, x, ref, ref_center): + # The attention_maps debugger outputs . Save that here. + self.lr = x.detach().cpu() + + ref_code = checkpoint(self.reference_embedding, ref, ref_center) + ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) + + x = self.model_fea_conv(x) + x1, a1 = checkpoint(self.sw1, x, ref_embedding) + x2, a2 = checkpoint(self.sw2, x1, ref_embedding) + x3, a3 = checkpoint(self.sw3, x2, ref_embedding) + x_out = checkpoint(self.final_lr_conv, x3) + x_out = checkpoint(self.upsample, x_out) + x_out = checkpoint(self.final_hr_conv2, x_out) + + self.attentions = [a1, a3, a3] + return x_out + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, 1 + self.init_temperature * + (self.final_temperature_step - step) / self.final_temperature_step) + self.set_temperature(temp) + if step % 200 == 0: + output_path = os.path.join(experiments_path, "attention_maps") + prefix = "amap_%i_a%i_%%i.png" + [save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] + torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) + + + def get_debug_values(self, step, net_name): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val + + +class SSGDeep(nn.Module): + def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): + super(SSGDeep, self).__init__() + n_upscale = int(math.log(upscale, 2)) + self.nf = nf + + # processing the input embedding + self.reference_embedding = ReferenceImageBranch(nf) + + # Feature branch + self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) + self.sw1 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) + + # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. + self.get_g_nopadding = ImageGradientNoPadding() + self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False) + self.sw_grad = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=True) + self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) + self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False) + self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True) + + # Join branch (grad+fea) + self.conjoin_sw = SwitchWithReference(nf, xforms, init_temperature, has_ref=True) + self.sw3 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) + self.sw4 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) + self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) + self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) + self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) + self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False) + self.switches = [self.sw1.switch, self.sw_grad.switch, self.conjoin_sw.switch, self.sw3.switch, self.sw4.switch] + self.attentions = None + self.lr = None + self.init_temperature = init_temperature + self.final_temperature_step = 10000 + + def forward(self, x, ref, ref_center): + # The attention_maps debugger outputs . Save that here. + self.lr = x.detach().cpu() + + x_grad = self.get_g_nopadding(x) + ref_code = checkpoint(self.reference_embedding, ref, ref_center) + ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) + + x = self.model_fea_conv(x) + x1, a1 = checkpoint(self.sw1, x, ref_embedding) + + x_grad = self.grad_conv(x_grad) + x_grad, a3, grad_fea_std = checkpoint(self.sw_grad, x_grad, ref_embedding, x1) + x_grad = checkpoint(self.grad_lr_conv, x_grad) + x_grad_out = checkpoint(self.upsample_grad, x_grad) + x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out) + + x_out, a4, fea_grad_std = checkpoint(self.conjoin_sw, x1, ref_embedding, x_grad) + x_out, a5 = checkpoint(self.sw3, x_out, ref_embedding) + x_out, a6 = checkpoint(self.sw4, x_out, ref_embedding) + x_out = checkpoint(self.final_lr_conv, x_out) + x_out = checkpoint(self.upsample, x_out) + x_out = checkpoint(self.final_hr_conv2, x_out) + + self.attentions = [a1, a3, a4, a5, a6] + self.grad_fea_std = grad_fea_std.detach().cpu() + self.fea_grad_std = fea_grad_std.detach().cpu() + return x_grad_out, x_out, x_grad + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, 1 + self.init_temperature * + (self.final_temperature_step - step) / self.final_temperature_step) + self.set_temperature(temp) + if step % 200 == 0: + output_path = os.path.join(experiments_path, "attention_maps") + prefix = "amap_%i_a%i_%%i.png" + [save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] + torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) + + + def get_debug_values(self, step, net_name): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp, + "grad_branch_feat_intg_std_dev": self.grad_fea_std, + "conjoin_branch_grad_intg_std_dev": self.fea_grad_std} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index 8e309a6f..823be304 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -82,13 +82,13 @@ def define_G(opt, net_key='network_G', scale=None): xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = ssg.SSGr1(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) - elif which_model == 'ssg_no_embedding': + elif which_model == 'stacked_switches': xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 - netG = ssg.SSGNoEmbedding(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + netG = ssg.StackedSwitchGenerator(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) - elif which_model == 'ssg_lite': + elif which_model == 'ssg_deep': xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 - netG = ssg.SSGLite(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == "backbone_encoder": netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) diff --git a/codes/train.py b/codes/train.py index e61abe71..2b55376d 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_ssgr_constrained_gan.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_ssgr_deep.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()