diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index 90a9e426..31ea8c90 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -52,150 +52,6 @@ class BasicEmbeddingPyramid(nn.Module): return x, p -class ChainedEmbeddingGen(nn.Module): - def __init__(self, depth=10, in_nc=3): - super(ChainedEmbeddingGen, self).__init__() - self.initial_conv = ConvGnLelu(in_nc, 64, kernel_size=7, bias=True, norm=False, activation=False) - self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) - self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)]) - self.upsample = FinalUpsampleBlock2x(64, out_nc=in_nc) - - def forward(self, x): - fea = self.initial_conv(x) - emb = checkpoint(self.spine, fea) - for block in self.blocks: - fea = fea + checkpoint(block, fea, *emb)[0] - return checkpoint(self.upsample, fea), - - -class ChainedEmbeddingGenWithStructure(nn.Module): - def __init__(self, in_nc=3, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2): - super(ChainedEmbeddingGenWithStructure, self).__init__() - self.recurrent = recurrent - self.initial_conv = ConvGnLelu(in_nc, 64, kernel_size=7, bias=True, norm=False, activation=False) - if recurrent: - self.recurrent_nf = recurrent_nf - self.recurrent_stride = recurrent_stride - self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False) - self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) - self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)]) - self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)]) - self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)]) - self.structure_upsample = FinalUpsampleBlock2x(64) - self.grad_extract = ImageGradientNoPadding() - self.upsample = FinalUpsampleBlock2x(64) - self.ref_join_std = 0 - - def forward(self, x, recurrent=None): - fea = self.initial_conv(x) - if self.recurrent: - if recurrent is None: - if self.recurrent_nf == 3: - recurrent = torch.zeros_like(x) - if self.recurrent_stride != 1: - recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest') - else: - recurrent = torch.zeros_like(fea) - rec = self.recurrent_process(recurrent) - fea, recstd = self.recurrent_join(fea, rec) - self.ref_join_std = recstd.item() - emb = checkpoint(self.spine, fea) - grad = fea - for i, block in enumerate(self.blocks): - fea = fea + checkpoint(block, fea, *emb)[0] - if i < 3: - structure_br = checkpoint(self.structure_joins[i], grad, fea) - grad = grad + checkpoint(self.structure_blocks[i], structure_br) - out = checkpoint(self.upsample, fea) - return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea - - def get_debug_values(self, step, net_name): - return { 'ref_join_std': self.ref_join_std } - - -# This is a structural block that learns to mute regions of a residual transformation given a signal. 
-class OptionalPassthroughBlock(nn.Module): - def __init__(self, nf, initial_bias=10): - super(OptionalPassthroughBlock, self).__init__() - self.switch_process = nn.Sequential(ConvGnLelu(nf, nf // 2, 1, activation=False, norm=False, bias=False), - ConvGnLelu(nf // 2, nf // 4, 1, activation=False, norm=False, bias=False), - ConvGnLelu(nf // 4, 1, 1, activation=False, norm=False, bias=False)) - self.bias = nn.Parameter(torch.tensor(initial_bias, dtype=torch.float), requires_grad=True) - self.activation = nn.Sigmoid() - - def forward(self, x, switch_signal): - switch = self.switch_process(switch_signal) - bypass_map = self.activation(self.bias + switch) - return x * bypass_map, bypass_map - - -class StructuredChainedEmbeddingGenWithBypass(nn.Module): - def __init__(self, depth=10, recurrent=False, recurrent_nf=3, recurrent_stride=2, bypass_bias=10): - super(StructuredChainedEmbeddingGenWithBypass, self).__init__() - self.recurrent = recurrent - self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) - if recurrent: - self.recurrent_nf = recurrent_nf - self.recurrent_stride = recurrent_stride - self.recurrent_process = ConvGnLelu(recurrent_nf, 64, kernel_size=3, stride=recurrent_stride, norm=False, bias=True, activation=False) - self.recurrent_join = ReferenceJoinBlock(64, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) - self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)]) - self.bypasses = nn.ModuleList([OptionalPassthroughBlock(64, initial_bias=bypass_bias) for i in range(depth)]) - self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)]) - self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)]) - self.structure_upsample = FinalUpsampleBlock2x(64) - self.grad_extract = ImageGradientNoPadding() - self.upsample = FinalUpsampleBlock2x(64) - self.ref_join_std = 0 - self.block_residual_means = [0 for _ in range(depth)] - self.block_residual_stds = [0 for _ in range(depth)] - self.bypass_maps = [] - - def forward(self, x, recurrent=None): - fea = self.initial_conv(x) - if self.recurrent: - if recurrent is None: - if self.recurrent_nf == 3: - recurrent = torch.zeros_like(x) - if self.recurrent_stride != 1: - recurrent = torch.nn.functional.interpolate(recurrent, scale_factor=self.recurrent_stride, mode='nearest') - else: - recurrent = torch.zeros_like(fea) - rec = self.recurrent_process(recurrent) - fea, recstd = self.recurrent_join(fea, rec) - self.ref_join_std = recstd.item() - emb = checkpoint(self.spine, fea) - grad = fea - self.bypass_maps = [] - for i, block in enumerate(self.blocks): - residual, context = checkpoint(block, fea, *emb) - residual, bypass_map = checkpoint(self.bypasses[i], residual, context) - fea = fea + residual - self.bypass_maps.append(bypass_map.detach()) - self.block_residual_means[i] = residual.mean().item() - self.block_residual_stds[i] = residual.std().item() - if i < 3: - structure_br = checkpoint(self.structure_joins[i], grad, fea) - grad = grad + checkpoint(self.structure_blocks[i], structure_br) - out = checkpoint(self.upsample, fea) - return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out), fea - - def visual_dbg(self, step, path): - for i, bm in enumerate(self.bypass_maps): - torchvision.utils.save_image(bm.cpu().float(), 
os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) - - def get_debug_values(self, step, net_name): - biases = [b.bias.item() for b in self.bypasses] - blk_stds, blk_means = {}, {} - for i, (s, m) in enumerate(zip(self.block_residual_stds, self.block_residual_means)): - blk_stds['block_%i' % (i+1,)] = s - blk_means['block_%i' % (i+1,)] = m - return {'ref_join_std': self.ref_join_std, 'bypass_biases': sum(biases) / len(biases), - 'blocks_std': blk_stds, 'blocks_mean': blk_means} - - class MultifacetedChainedEmbeddingGen(nn.Module): def __init__(self, depth=10, scale=2): super(MultifacetedChainedEmbeddingGen, self).__init__() diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index ab8277e7..1937b0b5 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -198,263 +198,6 @@ class SPSRNetSimplified(nn.Module): ######### return x_out_branch, x_out, x_grad -class Spsr5(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=2, init_temperature=10): - super(Spsr5, self).__init__() - n_upscale = int(math.log(upscale, 2)) - - # switch options - transformation_filters = nf - self.transformation_counts = xforms - multiplx_fn = functools.partial(QueryKeyMultiplexer, transformation_filters, reductions=multiplexer_reductions) - pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) - transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), - transformation_filters, kernel_size=3, depth=3, - weight_init_factor=.1) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) - self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1) - self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - - # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. 
- self.get_g_nopadding = ImageGradientNoPadding() - self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) - self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.1) - self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False) - self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts // 2, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)]) - self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True) - - # Join branch (grad+fea) - self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1) - self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3) - self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) - self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False) - self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw] - self.attentions = None - self.init_temperature = init_temperature - self.final_temperature_step = 10000 - self.lr = None - - def forward(self, x, embedding): - # The attention_maps debugger outputs . Save that here. 
- self.lr = x.detach().cpu() - - noise_stds = [] - - x_grad = self.get_g_nopadding(x) - - x = self.model_fea_conv(x) - x1 = x - x1, a1 = self.sw1(x1, identity=x, att_in=(x1, embedding)) - - x2 = x1 - x2, nstd = self.noise_ref_join(x2, torch.randn_like(x2)) - x2, a2 = self.sw2(x2, identity=x1, att_in=(x2, embedding)) - noise_stds.append(nstd) - - x_grad = self.grad_conv(x_grad) - x_grad_identity = x_grad - x_grad, nstd = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad)) - x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1) - x_grad, a3 = self.sw_grad(x_grad, identity=x_grad_identity, att_in=(x_grad, embedding)) - x_grad = self.grad_lr_conv(x_grad) - x_grad = self.grad_lr_conv2(x_grad) - x_grad_out = self.upsample_grad(x_grad) - x_grad_out = self.grad_branch_output_conv(x_grad_out) - noise_stds.append(nstd) - - x_out = x2 - x_out, nstd = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out)) - x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad) - x_out, a4 = self.conjoin_sw(x_out, identity=x2, att_in=(x_out, embedding)) - x_out = self.final_lr_conv(x_out) - x_out = self.upsample(x_out) - x_out = self.final_hr_conv1(x_out) - x_out = self.final_hr_conv2(x_out) - noise_stds.append(nstd) - - self.attentions = [a1, a2, a3, a4] - self.noise_stds = torch.stack(noise_stds).mean().detach().cpu() - self.grad_fea_std = grad_fea_std.detach().cpu() - self.fea_grad_std = fea_grad_std.detach().cpu() - return x_grad_out, x_out, x_grad - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, 1 + self.init_temperature * - (self.final_temperature_step - step) / self.final_temperature_step) - self.set_temperature(temp) - if step % 500 == 0: - output_path = os.path.join(experiments_path, "attention_maps") - prefix = "amap_%i_a%i_%%i.png" - [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] - torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) - - def get_debug_values(self, step, net_name): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp, - "noise_branch_std_dev": self.noise_stds, - "grad_branch_feat_intg_std_dev": self.grad_fea_std, - "conjoin_branch_grad_intg_std_dev": self.fea_grad_std} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val - - -# Variant of Spsr5 which uses multiplexer blocks that are not derived from an embedding. 
Also makes a few "best practices" -# adjustments learned over the past few weeks (no noise, kernel_size=7 -class Spsr6(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, init_temperature=10): - super(Spsr6, self).__init__() - n_upscale = int(math.log(upscale, 2)) - - # switch options - transformation_filters = nf - self.transformation_counts = xforms - multiplx_fn = functools.partial(QueryKeyPyramidMultiplexer, transformation_filters, reductions=multiplexer_reductions) - transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), - transformation_filters, kernel_size=3, depth=3, - weight_init_factor=.1) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) - self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False) - self.feature_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) - - # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. - self.get_g_nopadding = ImageGradientNoPadding() - self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) - self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False) - self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts // 2, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=1, norm=False, activation=True, bias=True) - self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)]) - self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True) - - # Join branch (grad+fea) - self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1) - self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3) - self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) - self.final_hr_conv1 = ConvGnLelu(nf, nf, 
kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) - self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw] - self.attentions = None - self.init_temperature = init_temperature - self.final_temperature_step = 10000 - self.lr = None - - def forward(self, x): - # The attention_maps debugger outputs . Save that here. - self.lr = x.detach().cpu() - - x_grad = self.get_g_nopadding(x) - - x = self.model_fea_conv(x) - x1 = x - x1, a1 = self.sw1(x1, identity=x) - - x2 = x1 - x2, a2 = self.sw2(x2, identity=x1) - - x_grad = self.grad_conv(x_grad) - x_grad_identity = x_grad - x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1) - x_grad, a3 = self.sw_grad(x_grad, identity=x_grad_identity) - x_grad = self.grad_lr_conv(x_grad) - x_grad = self.grad_lr_conv2(x_grad) - x_grad_out = self.upsample_grad(x_grad) - x_grad_out = self.grad_branch_output_conv(x_grad_out) - - x_out = x2 - x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad) - x_out, a4 = self.conjoin_sw(x_out, identity=x2) - x_out = self.final_lr_conv(x_out) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv1, x_out) - x_out = self.final_hr_conv2(x_out) - - self.attentions = [a1, a2, a3, a4] - self.grad_fea_std = grad_fea_std.detach().cpu() - self.fea_grad_std = fea_grad_std.detach().cpu() - return x_grad_out, x_out, x_grad - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, 1 + self.init_temperature * - (self.final_temperature_step - step) / self.final_temperature_step) - self.set_temperature(temp) - if step % 500 == 0: - output_path = os.path.join(experiments_path, "attention_maps") - prefix = "amap_%i_a%i_%%i.png" - [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] - torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) - - def get_debug_values(self, step, net_name): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp, - "grad_branch_feat_intg_std_dev": self.grad_fea_std, - "conjoin_branch_grad_intg_std_dev": self.fea_grad_std} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val # Variant of Spsr6 which uses multiplexer blocks that feed off of a reference embedding. Also computes that embedding. class Spsr7(nn.Module): @@ -623,109 +366,6 @@ class AttentionBlock(nn.Module): return self.switch(x, identity=x, att_in=(x, mplex_ref)) -# SPSR7 with incremental improvements and also using the new AttentionBlock to save gpu memory. 
-class Spsr9(nn.Module): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, multiplexer_reductions=3, init_temperature=10): - super(Spsr9, self).__init__() - n_upscale = int(math.log(upscale, 2)) - self.nf = nf - self.transformation_counts = xforms - - # processing the input embedding - self.reference_embedding = ReferenceImageBranch(nf) - - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) - self.sw1 = AttentionBlock(nf, self.transformation_counts, multiplexer_reductions, init_temperature, False) - self.sw2 = AttentionBlock(nf, self.transformation_counts, multiplexer_reductions, init_temperature, False) - - # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. - self.get_g_nopadding = ImageGradientNoPadding() - self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False) - self.sw_grad = AttentionBlock(nf, self.transformation_counts // 2, multiplexer_reductions, init_temperature, True) - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.grad_lr_conv2 = ConvGnLelu(nf, nf, kernel_size=1, norm=False, activation=True, bias=True) - self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)]) - self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True) - - # Join branch (grad+fea) - self.conjoin_sw = AttentionBlock(nf, self.transformation_counts, multiplexer_reductions, init_temperature, True) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) - self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) - self.switches = [self.sw1.switch, self.sw2.switch, self.sw_grad.switch, self.conjoin_sw.switch] - self.attentions = None - self.init_temperature = init_temperature - self.final_temperature_step = 10000 - self.lr = None - - def forward(self, x, ref, ref_center, update_attention_norm=True): - # The attention_maps debugger outputs . Save that here. 
- self.lr = x.detach().cpu() - - for sw in self.switches: - sw.set_update_attention_norm(update_attention_norm) - - x_grad = self.get_g_nopadding(x) - ref_code = checkpoint(self.reference_embedding, ref, ref_center) - ref_embedding = ref_code.view(-1, self.nf * 8, 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) - - x = self.model_fea_conv(x) - x1 = x - x1, a1 = checkpoint(self.sw1, x1, ref_embedding) - x2 = x1 - x2, a2 = checkpoint(self.sw2, x2, ref_embedding) - - x_grad = self.grad_conv(x_grad) - x_grad_identity = x_grad - x_grad, a3, grad_fea_std = checkpoint(self.sw_grad, x_grad, ref_embedding, x1) - x_grad = self.grad_lr_conv(x_grad) - x_grad = self.grad_lr_conv2(x_grad) - x_grad_out = checkpoint(self.upsample_grad, x_grad) - x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out) - - x_out = x2 - x_out, a4, fea_grad_std = checkpoint(self.conjoin_sw, x_out, ref_embedding, x_grad) - x_out = self.final_lr_conv(x_out) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv1, x_out) - x_out = checkpoint(self.final_hr_conv2, x_out) - - self.attentions = [a1, a2, a3, a4] - self.grad_fea_std = grad_fea_std.detach().cpu() - self.fea_grad_std = fea_grad_std.detach().cpu() - return x_grad_out, x_out - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, 1 + self.init_temperature * - (self.final_temperature_step - step) / self.final_temperature_step) - self.set_temperature(temp) - if step % 500 == 0: - output_path = os.path.join(experiments_path, "attention_maps") - prefix = "amap_%i_a%i_%%i.png" - [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] - torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) - - def get_debug_values(self, step, net_name): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp, - "grad_branch_feat_intg_std_dev": self.grad_fea_std, - "conjoin_branch_grad_intg_std_dev": self.fea_grad_std} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val - - class SwitchedSpsr(nn.Module): def __init__(self, in_nc, nf, xforms=8, upscale=4, init_temperature=10): super(SwitchedSpsr, self).__init__() diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py deleted file mode 100644 index a677b8b0..00000000 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ /dev/null @@ -1,576 +0,0 @@ -import functools -import math - -import torch -import torch.nn.functional as F -from torch import nn - -from models.archs.SPSR_arch import ImageGradientNoPadding -from models.archs.SwitchedResidualGenerator_arch import ConfigurableSwitchComputer, gather_2d, SwitchModelBase -from models.archs.arch_util import MultiConvBlock, ConvGnLelu, ConvGnSilu, ReferenceJoinBlock -from utils.util import checkpoint - - -# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation -# Doubles the input filter count. 
-class HalvingProcessingBlock(nn.Module): - def __init__(self, filters, factor=2): - super(HalvingProcessingBlock, self).__init__() - self.bnconv1 = ConvGnSilu(filters, filters, norm=False, bias=False) - self.bnconv2 = ConvGnSilu(filters, int(filters * factor), kernel_size=1, stride=2, norm=True, bias=False) - - def forward(self, x): - x = self.bnconv1(x) - return self.bnconv2(x) - - -class ExpansionBlock2(nn.Module): - def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, factor=2): - super(ExpansionBlock2, self).__init__() - if filters_out is None: - filters_out = int(filters_in / factor) - self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=True, norm=False) - self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=True, norm=False) - self.conjoin = block(filters_out*2, filters_out*2, kernel_size=1, bias=False, activation=True, norm=False) - self.reduce = block(filters_out*2, filters_out, kernel_size=1, bias=False, activation=False, norm=True) - - # input is the feature signal with shape (b, f, w, h) - # passthrough is the structure signal with shape (b, f/2, w*2, h*2) - # output is conjoined upsample with shape (b, f/2, w*2, h*2) - def forward(self, input, passthrough): - x = F.interpolate(input, scale_factor=2, mode="nearest") - x = self.decimate(x) - p = self.process_passthrough(passthrough) - x = self.conjoin(torch.cat([x, p], dim=1)) - return self.reduce(x) - - -# Basic convolutional upsampling block that uses interpolate. -class UpconvBlock(nn.Module): - def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False): - super(UpconvBlock, self).__init__() - self.reduce = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=False) - self.process = block(filters_out, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm) - - def forward(self, x): - x = self.reduce(x) - x = F.interpolate(x, scale_factor=2, mode="nearest") - return self.process(x) - - -class QueryKeyMultiplexer(nn.Module): - def __init__(self, nf, multiplexer_channels, embedding_channels=216, reductions=3): - super(QueryKeyMultiplexer, self).__init__() - - # Blocks used to create the query - self.input_process = ConvGnSilu(nf, nf, activation=True, norm=False, bias=True) - self.embedding_process = ConvGnSilu(embedding_channels, 128, kernel_size=1, activation=True, norm=False, bias=True) - self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(int(nf * 1.5 ** i), factor=1.5) for i in range(reductions)]) - reduction_filters = int(nf * 1.5 ** reductions) - self.processing_blocks = nn.Sequential( - ConvGnSilu(reduction_filters + 128, reduction_filters + 64, kernel_size=1, activation=True, norm=False, bias=True), - ConvGnSilu(reduction_filters + 64, reduction_filters, kernel_size=1, activation=True, norm=False, bias=False), - ConvGnSilu(reduction_filters, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False), - ConvGnSilu(reduction_filters, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False)) - self.expansion_blocks = nn.ModuleList([ExpansionBlock2(int(reduction_filters // (1.5 ** i)), factor=1.5) for i in range(reductions)]) - - # Blocks used to create the key - self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=False) - - # Postprocessing blocks. 
- self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=1, activation=True, norm=False, bias=False) - self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4) - self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False) - - def forward(self, x, embedding, transformations): - q = self.input_process(x) - embedding = self.embedding_process(embedding) - reduction_identities = [] - for b in self.reduction_blocks: - reduction_identities.append(q) - q = b(q) - q = self.processing_blocks(torch.cat([q, embedding], dim=1)) - for i, b in enumerate(self.expansion_blocks): - q = b(q, reduction_identities[-i - 1]) - - b, t, f, h, w = transformations.shape - k = transformations.view(b * t, f, h, w) - k = self.key_process(k) - - q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w) - v = self.query_key_combine(torch.cat([q, k], dim=1)) - - v = self.cbl1(v) - v = self.cbl2(v) - - return v.view(b, t, h, w) - - -# Computes a linear latent by performing processing on the reference image and returning the filters of a single point, -# which should be centered on the image patch being processed. -# -# Output is base_filters * 1.5^3. -class ReferenceImageBranch(nn.Module): - def __init__(self, base_filters=64): - super(ReferenceImageBranch, self).__init__() - final_filters = int(base_filters*1.5**3) - self.features = nn.Sequential(ConvGnSilu(4, base_filters, kernel_size=7, bias=True), - HalvingProcessingBlock(base_filters, factor=1.5), - HalvingProcessingBlock(int(base_filters*1.5), factor=1.5), - HalvingProcessingBlock(int(base_filters*1.5**2), factor=1.5), - ConvGnSilu(final_filters, final_filters, activation=True, norm=True, bias=False)) - - # center_point is a [b,2] long tensor describing the center point of where the patch was taken from the reference - # image. - def forward(self, x, center_point): - x = self.features(x) - return gather_2d(x, center_point // 8) # Divide by 8 to scale the center_point down. 
- -class SwitchWithReference(nn.Module): - def __init__(self, nf, num_transforms, init_temperature=10, has_ref=True): - super(SwitchWithReference, self).__init__() - self.nf = nf - self.transformation_counts = num_transforms - multiplx_fn = functools.partial(QueryKeyMultiplexer, nf) - transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.25), nf, kernel_size=3, depth=4, weight_init_factor=.1) - if has_ref: - self.ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False, kernel_size=1, depth=2) - else: - self.ref_join = None - self.switch = ConfigurableSwitchComputer(nf, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) - - def forward(self, x, mplex_ref=None, ref=None): - if self.ref_join is not None: - branch, ref_std = self.ref_join(x, ref) - return self.switch(branch, identity=x, att_in=(branch, mplex_ref)) + (ref_std,) - else: - return self.switch(x, identity=x, att_in=(x, mplex_ref)) - - -class SSGr1(SwitchModelBase): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10, recurrent=False): - super(SSGr1, self).__init__(init_temperature, 10000) - n_upscale = int(math.log(upscale, 2)) - self.nf = nf - - if recurrent: - self.recurrent = True - self.recurrent_process = ConvGnLelu(in_nc, nf, kernel_size=3, stride=2, norm=False, bias=True, activation=False) - self.recurrent_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - else: - self.recurrent = False - - # processing the input embedding - self.reference_embedding = ReferenceImageBranch(nf) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) - self.sw1 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - - # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. - self.get_g_nopadding = ImageGradientNoPadding() - self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) - self.sw_grad = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=True) - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False) - self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True) - - # Join branch (grad+fea) - self.conjoin_sw = SwitchWithReference(nf, xforms, init_temperature, has_ref=True) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) - self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False) - self.switches = [self.sw1.switch, self.sw_grad.switch, self.conjoin_sw.switch] - - def forward(self, x, ref, ref_center, save_attentions=True, recurrent=None): - # The attention_maps debugger outputs . Save that here. - self.lr = x.detach().cpu() - - # If we're not saving attention, we also shouldn't be updating the attention norm. 
This is because the attention - # norm should only be getting updates with new data, not recurrent generator sampling. - for sw in self.switches: - sw.set_update_attention_norm(save_attentions) - - x_grad = self.get_g_nopadding(x) - ref_code = checkpoint(self.reference_embedding, ref, ref_center) - ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) - - x = self.model_fea_conv(x) - if self.recurrent: - rec = self.recurrent_process(recurrent) - x, recurrent_join_std = self.recurrent_join(x, rec) - else: - recurrent_join_std = 0 - x1, a1 = checkpoint(self.sw1, x, ref_embedding) - - x_grad = self.grad_conv(x_grad) - x_grad, a3, grad_fea_std = checkpoint(self.sw_grad, x_grad, ref_embedding, x1) - x_grad = checkpoint(self.grad_lr_conv, x_grad) - x_grad_out = checkpoint(self.upsample_grad, x_grad) - x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out) - - x_out, a4, fea_grad_std = checkpoint(self.conjoin_sw, x1, ref_embedding, x_grad) - x_out = checkpoint(self.final_lr_conv, x_out) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv2, x_out) - - if save_attentions: - self.attentions = [a1, a3, a4] - self.grad_fea_std = grad_fea_std.detach().cpu() - self.fea_grad_std = fea_grad_std.detach().cpu() - return x_grad_out, x_out, x_grad - - -class StackedSwitchGenerator(SwitchModelBase): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): - super(StackedSwitchGenerator, self).__init__(init_temperature, 10000) - n_upscale = int(math.log(upscale, 2)) - self.nf = nf - - # processing the input embedding - self.reference_embedding = ReferenceImageBranch(nf) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) - self.sw1 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.sw2 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.sw3 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.switches = [self.sw1.switch, self.sw2.switch, self.sw3.switch] - - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) - self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False) - - def forward(self, x, ref, ref_center, save_attentions=True): - # The attention_maps debugger outputs . Save that here. 
- self.lr = x.detach().cpu() - - ref_code = checkpoint(self.reference_embedding, ref, ref_center) - ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) - - x = self.model_fea_conv(x) - x1, a1 = checkpoint(self.sw1, x, ref_embedding) - x2, a2 = checkpoint(self.sw2, x1, ref_embedding) - x3, a3 = checkpoint(self.sw3, x2, ref_embedding) - x_out = checkpoint(self.final_lr_conv, x3) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv2, x_out) - - if save_attentions: - self.attentions = [a1, a3, a3] - return x_out, - - -class SSGDeep(SwitchModelBase): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10, recurrent=False): - super(SSGDeep, self).__init__(init_temperature, 10000) - n_upscale = int(math.log(upscale, 2)) - self.nf = nf - - # processing the input embedding - if recurrent: - self.recurrent = True - self.recurrent_process = ConvGnLelu(in_nc, nf, kernel_size=3, stride=2, norm=False, bias=True, activation=False) - self.recurrent_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - else: - self.recurrent = False - self.reference_embedding = ReferenceImageBranch(nf) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) - self.sw1 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.sw2 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - - # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. - self.get_g_nopadding = ImageGradientNoPadding() - self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False) - self.sw_grad = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=True) - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False) - self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True) - - # Join branch (grad+fea) - self.conjoin_sw = SwitchWithReference(nf, xforms, init_temperature, has_ref=True) - self.sw4 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) - self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False) - self.switches = [self.sw1.switch, self.sw2.switch, self.sw_grad.switch, self.conjoin_sw.switch, self.sw4.switch] - - def forward(self, x, ref, ref_center, save_attentions=True, recurrent=None): - # The attention_maps debugger outputs . Save that here. - self.lr = x.detach().cpu() - - # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention - # norm should only be getting updates with new data, not recurrent generator sampling. 
- for sw in self.switches: - sw.set_update_attention_norm(save_attentions) - - x_grad = self.get_g_nopadding(x) - ref_code = checkpoint(self.reference_embedding, ref, ref_center) - ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) - - x = self.model_fea_conv(x) - if self.recurrent: - rec = self.recurrent_process(recurrent) - x, recurrent_std = self.recurrent_join(x, rec) - x1, a1 = checkpoint(self.sw1, x, ref_embedding) - x2, a2 = checkpoint(self.sw2, x1, ref_embedding) - - x_grad = self.grad_conv(x_grad) - x_grad, a3, grad_fea_std = checkpoint(self.sw_grad, x_grad, ref_embedding, x1) - x_grad = checkpoint(self.grad_lr_conv, x_grad) - x_grad_out = checkpoint(self.upsample_grad, x_grad) - x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out) - - x3, a4, fea_grad_std = checkpoint(self.conjoin_sw, x2, ref_embedding, x_grad) - x_out, a5 = checkpoint(self.sw4, x3, ref_embedding) - x_out = checkpoint(self.final_lr_conv, x_out) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv2, x_out) - - if save_attentions: - self.attentions = [a1, a2, a3, a4, a5] - self.grad_fea_std = grad_fea_std.detach().cpu() - self.fea_grad_std = fea_grad_std.detach().cpu() - return x_grad_out, x_out - - -class StackedSwitchGenerator5Layer(SwitchModelBase): - def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10): - super(StackedSwitchGenerator5Layer, self).__init__(init_temperature, 10000) - n_upscale = int(math.log(upscale, 2)) - self.nf = nf - - # processing the input embedding - self.reference_embedding = ReferenceImageBranch(nf) - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) - self.sw1 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.sw2 = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=False) - self.sw3 = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=False) - self.sw4 = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=False) - self.sw5 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.switches = [self.sw1.switch, self.sw2.switch, self.sw3.switch, self.sw4.switch, self.sw5.switch] - - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) - self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False) - - def forward(self, x, ref, ref_center, save_attentions=True): - # The attention_maps debugger outputs . Save that here. - self.lr = x.detach().cpu() - - # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention - # norm should only be getting updates with new data, not recurrent generator sampling. 
- for sw in self.switches: - sw.set_update_attention_norm(save_attentions) - - ref_code = checkpoint(self.reference_embedding, ref, ref_center) - ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) - - x = self.model_fea_conv(x) - x1, a1 = checkpoint(self.sw1, x, ref_embedding) - x2, a2 = checkpoint(self.sw2, x1, ref_embedding) - x3, a3 = checkpoint(self.sw3, x2, ref_embedding) - x4, a4 = checkpoint(self.sw4, x3, ref_embedding) - x5, a5 = checkpoint(self.sw5, x4, ref_embedding) - x_out = checkpoint(self.final_lr_conv, x5) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv2, x_out) - - if save_attentions: - self.attentions = [a1, a3, a3, a4, a5] - return x_out, - - -class StackedSwitchGenerator2xTeco(SwitchModelBase): - def __init__(self, nf, xforms=8, init_temperature=10): - super(StackedSwitchGenerator2xTeco, self).__init__(init_temperature, 10000) - self.nf = nf - - # processing the input embedding - self.reference_embedding = ReferenceImageBranch(nf) - - # Feature branch - self.model_fea_conv = ConvGnLelu(3, nf, kernel_size=7, norm=False, activation=False, bias=True) - self.model_recurrent_conv = ConvGnLelu(3, nf, kernel_size=3, stride=2, norm=False, activation=False, bias=True) - self.model_fea_recurrent_combine = ConvGnLelu(nf*2, nf, 1, activation=False, norm=False, bias=False) - self.sw1 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.sw2 = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=False) - self.sw3 = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=False) - self.sw4 = SwitchWithReference(nf, xforms // 2, init_temperature, has_ref=False) - self.sw5 = SwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.switches = [self.sw1.switch, self.sw2.switch, self.sw3.switch, self.sw4.switch, self.sw5.switch] - - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) - self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False) - - def forward(self, x, recurrent, ref, ref_center, save_attentions=True): - # The attention_maps debugger outputs . Save that here. - self.lr = x.detach().cpu() - - # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention - # norm should only be getting updates with new data, not recurrent generator sampling. 
- for sw in self.switches: - sw.set_update_attention_norm(save_attentions) - - ref_code = checkpoint(self.reference_embedding, ref, ref_center) - ref_embedding = ref_code.view(-1, ref_code.shape[1], 1, 1).repeat(1, 1, x.shape[2] // 8, x.shape[3] // 8) - - x = self.model_fea_conv(x) - rec = self.model_recurrent_conv(recurrent) - x = self.model_fea_recurrent_combine(torch.cat([x, rec], dim=1)) - x1, a1 = checkpoint(self.sw1, x, ref_embedding) - x2, a2 = checkpoint(self.sw2, x1, ref_embedding) - x3, a3 = checkpoint(self.sw3, x2, ref_embedding) - x4, a4 = checkpoint(self.sw4, x3, ref_embedding) - x5, a5 = checkpoint(self.sw5, x4, ref_embedding) - x_out = checkpoint(self.final_lr_conv, x5) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv2, x_out) - - if save_attentions: - self.attentions = [a1, a3, a3, a4, a5] - return x_out, - - -class SimplePyramidMultiplexer(nn.Module): - def __init__(self, nf, transforms): - super(SimplePyramidMultiplexer, self).__init__() - - # Blocks used to create the query - reductions = 3 - self.input_process = ConvGnSilu(nf, nf, activation=True, norm=False, bias=True) - self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(int(nf * 1.5 ** i), factor=1.5) - for i in range(reductions)]) - reduction_filters = int(nf * 1.5 ** reductions) - self.processing_blocks = nn.Sequential( - ConvGnSilu(reduction_filters, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False), - ConvGnSilu(reduction_filters, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False)) - self.expansion_blocks = nn.ModuleList([ExpansionBlock2(int(reduction_filters // (1.5 ** i)), factor=1.5) - for i in range(reductions)]) - - self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=False, bias=False) - self.cbl2 = ConvGnSilu(nf // 2, transforms, kernel_size=1, norm=False, bias=False) - - def forward(self, x): - q = self.input_process(x) - reduction_identities = [] - for b in self.reduction_blocks: - reduction_identities.append(q) - q = b(q) - q = self.processing_blocks(q) - for i, b in enumerate(self.expansion_blocks): - q = b(q, reduction_identities[-i - 1]) - q = self.cbl1(q) - q = self.cbl2(q) - return q - - -class SimplerSwitchWithReference(nn.Module): - def __init__(self, nf, num_transforms, init_temperature=10, has_ref=True): - super(SimplerSwitchWithReference, self).__init__() - self.nf = nf - self.transformation_counts = num_transforms - multiplx_fn = functools.partial(SimplePyramidMultiplexer, nf) - pretransform = functools.partial(ConvGnLelu, nf, int(nf*1.5), kernel_size=3, bias=False, norm=False, activation=True, weight_init_factor=.1) - transform_fn = functools.partial(ConvGnLelu, int(nf * 1.5), int(nf * 1.5), kernel_size=3, bias=False, norm=False, activation=True, weight_init_factor=.1) - posttransform = ConvGnLelu(int(nf*1.5), nf, kernel_size=3, bias=False, norm=False, activation=True, weight_init_factor=.1) - if has_ref: - self.ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False, kernel_size=1, depth=2) - else: - self.ref_join = None - self.switch = ConfigurableSwitchComputer(nf, multiplx_fn, - pre_transform_block=pretransform, transform_block=transform_fn, - post_transform_block=posttransform, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False) - - def forward(self, x, ref=None): - if self.ref_join is not None: - branch, ref_std = self.ref_join(x, ref) - 
return self.switch(branch, identity=x) + (ref_std,) - else: - return self.switch(x, identity=x) - - -class SsgSimpler(SwitchModelBase): - def __init__(self, in_nc, out_nc, nf, xforms=8, init_temperature=10, recurrent=False): - super(SsgSimpler, self).__init__(init_temperature, 10000) - self.nf = nf - - # processing the input embedding - if recurrent: - self.recurrent = True - self.recurrent_process = ConvGnLelu(in_nc, nf, kernel_size=3, stride=2, norm=False, bias=True, activation=False) - self.recurrent_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, final_norm=False, kernel_size=1, depth=3, join=False) - else: - self.recurrent = False - - # Feature branch - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) - self.sw1 = SimplerSwitchWithReference(nf, xforms, init_temperature, has_ref=False) - self.sw2 = SimplerSwitchWithReference(nf, xforms, init_temperature, has_ref=False) - - # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague. - self.get_g_nopadding = ImageGradientNoPadding() - self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False, bias=False) - self.sw_grad = SimplerSwitchWithReference(nf, xforms // 2, init_temperature, has_ref=True) - self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample_grad = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=False) - self.grad_branch_output_conv = ConvGnLelu(nf // 2, out_nc, kernel_size=1, norm=False, activation=False, bias=True) - - # Join branch (grad+fea) - self.conjoin_sw = SimplerSwitchWithReference(nf, xforms, init_temperature, has_ref=True) - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) - self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True) - self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False) - self.switches = [self.sw1.switch, self.sw2.switch, self.sw_grad.switch, self.conjoin_sw.switch] - - def forward(self, x, save_attentions=True, recurrent=None): - # The attention_maps debugger outputs . Save that here. - self.lr = x.detach().cpu() - - # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention - # norm should only be getting updates with new data, not recurrent generator sampling. 
- for sw in self.switches: - sw.set_update_attention_norm(save_attentions) - - x1 = self.model_fea_conv(x) - if self.recurrent: - rec = self.recurrent_process(recurrent) - x1, recurrent_std = self.recurrent_join(x1, rec) - x1, a1 = checkpoint(self.sw1, x1) - x2, a2 = checkpoint(self.sw2, x1) - - x_grad = self.get_g_nopadding(x) - x_grad = self.grad_conv(x_grad) - x_grad, a3, grad_fea_std = checkpoint(self.sw_grad, x_grad, x1) - x_grad = checkpoint(self.grad_lr_conv, x_grad) - x_grad_out = checkpoint(self.upsample_grad, x_grad) - x_grad_out = checkpoint(self.grad_branch_output_conv, x_grad_out) - - x3, a4, fea_grad_std = checkpoint(self.conjoin_sw, x2, x_grad) - x_out = checkpoint(self.final_lr_conv, x3) - x_out = checkpoint(self.upsample, x_out) - x_out = checkpoint(self.final_hr_conv2, x_out) - - if save_attentions: - self.attentions = [a1, a2, a3, a4] - self.grad_fea_std = grad_fea_std.detach().cpu() - self.fea_grad_std = fea_grad_std.detach().cpu() - return x_grad_out, x_out \ No newline at end of file diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 4885a35a..ec28cd96 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -614,131 +614,6 @@ class SwitchModelBase(nn.Module): return val -from models.archs.spinenet_arch import make_res_layer, BasicBlock -class BigMultiplexer(nn.Module): - def __init__(self, in_nc, nf, multiplexer_channels): - super(BigMultiplexer, self).__init__() - - self.spine = SpineNet(arch='96', output_level=[3], double_reduce_early=False) - self.spine_red_proc = ConvGnSilu(256, nf, kernel_size=1, activation=False, norm=False, bias=False) - self.fea_tail = ConvGnSilu(in_nc, nf, kernel_size=7, bias=True, norm=False, activation=False) - self.tail_proc = make_res_layer(BasicBlock, nf, nf, 2) - self.tail_join = ReferenceJoinBlock(nf) - - self.reduce = nn.Sequential(ConvGnSilu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=False), - ConvGnSilu(nf // 2, multiplexer_channels, kernel_size=1, activation=False, norm=False, bias=False)) - - def forward(self, x, transformations): - s = self.spine(x)[0] - tail = self.fea_tail(x) - tail = self.tail_proc(tail) - q = F.interpolate(s, scale_factor=2, mode='nearest') - q = self.spine_red_proc(q) - q, _ = self.tail_join(q, tail) - return self.reduce(q) - - -class TheBigSwitch(SwitchModelBase): - def __init__(self, in_nc, nf, xforms=16, upscale=2, init_temperature=10): - super(TheBigSwitch, self).__init__(init_temperature, 10000) - self.nf = nf - self.transformation_counts = xforms - - self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=7, norm=False, activation=False) - - multiplx_fn = functools.partial(BigMultiplexer, in_nc, nf) - transform_fn = functools.partial(MultiConvBlock, nf, int(nf * 1.5), nf, kernel_size=3, depth=4, weight_init_factor=.1) - self.switch = ConfigurableSwitchComputer(nf, multiplx_fn, - pre_transform_block=None, transform_block=transform_fn, - attention_norm=True, - transform_count=self.transformation_counts, init_temp=init_temperature, - add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True, - anorm_multiplier=128) - self.switches = [self.switch] - - self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) - self.final_hr_conv1 = ConvGnLelu(nf // 2, nf // 2, kernel_size=3, norm=False, 
-        self.final_hr_conv2 = ConvGnLelu(nf // 2, 3, kernel_size=3, norm=False, activation=False, bias=False)
-
-    def forward(self, x, save_attentions=True):
-        # The attention_maps debugger outputs . Save that here.
-        self.lr = x.detach().cpu()
-
-        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
-        # norm should only be getting updates with new data, not recurrent generator sampling.
-        for sw in self.switches:
-            sw.set_update_attention_norm(save_attentions)
-
-        x1 = self.model_fea_conv(x)
-        x1, a1 = self.switch(x1, att_in=x, do_checkpointing=True)
-        x_out = checkpoint(self.final_lr_conv, x1)
-        x_out = checkpoint(self.upsample, x_out)
-        x_out = checkpoint(self.final_hr_conv2, x_out)
-
-        if save_attentions:
-            self.attentions = [a1]
-        return x_out,
-
-
-class ArtistMultiplexer(nn.Module):
-    def __init__(self, in_nc, nf, multiplexer_channels):
-        super(ArtistMultiplexer, self).__init__()
-
-        self.spine = SpineNet(arch='96', output_level=[3], double_reduce_early=False)
-        self.spine_red_proc = ConvGnSilu(256, nf, kernel_size=1, activation=False, norm=False, bias=False)
-        self.fea_tail = ConvGnSilu(in_nc, nf, kernel_size=7, bias=True, norm=False, activation=False)
-        self.tail_proc = make_res_layer(BasicBlock, nf, nf, 2)
-        self.tail_join = ReferenceJoinBlock(nf)
-
-        self.reduce = ConvGnSilu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=False)
-        self.last_process = ConvGnSilu(nf // 2, nf // 2, kernel_size=1, activation=True, norm=False, bias=False)
-        self.to_attention = ConvGnSilu(nf // 2, multiplexer_channels, kernel_size=1, activation=False, norm=False, bias=False)
-
-    def forward(self, x, transformations):
-        s = self.spine(x)[0]
-        tail = self.fea_tail(x)
-        tail = self.tail_proc(tail)
-        q = F.interpolate(s, scale_factor=2, mode='nearest')
-        q = self.spine_red_proc(q)
-        q, _ = self.tail_join(q, tail)
-        q = self.reduce(q)
-        q = F.interpolate(q, scale_factor=2, mode='nearest')
-        return self.to_attention(self.last_process(q))
-
-
-class ArtistGen(SwitchModelBase):
-    def __init__(self, in_nc, nf, xforms=16, upscale=2, init_temperature=10):
-        super(ArtistGen, self).__init__(init_temperature, 10000)
-        self.nf = nf
-        self.transformation_counts = xforms
-
-        multiplx_fn = functools.partial(ArtistMultiplexer, in_nc, nf)
-        transform_fn = functools.partial(MultiConvBlock, in_nc, int(in_nc * 2), in_nc, kernel_size=3, depth=4, weight_init_factor=.1)
-        self.switch = ConfigurableSwitchComputer(nf, multiplx_fn,
-                                                 pre_transform_block=None, transform_block=transform_fn,
-                                                 attention_norm=True,
-                                                 transform_count=self.transformation_counts, init_temp=init_temperature,
-                                                 add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True,
-                                                 anorm_multiplier=128, post_switch_conv=False)
-        self.switches = [self.switch]
-
-    def forward(self, x, save_attentions=True):
-        # The attention_maps debugger outputs . Save that here.
-        self.lr = x.detach().cpu()
-
-        # If we're not saving attention, we also shouldn't be updating the attention norm. This is because the attention
-        # norm should only be getting updates with new data, not recurrent generator sampling.
-        for sw in self.switches:
-            sw.set_update_attention_norm(save_attentions)
-
-        up = F.interpolate(x, scale_factor=2, mode="bicubic")
-        out, a1, att_logits = self.switch(up, att_in=x, do_checkpointing=True, output_att_logits=True)
-
-        if save_attentions:
-            self.attentions = [a1]
-        return out, att_logits.permute(0,3,1,2)
-
 if __name__ == '__main__':
     tbs = TheBigSwitch(3, 64)
     x = torch.randn(4,3,64,64)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 176a6d2b..ea266a23 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -12,14 +12,12 @@ import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_
 import models.archs.RRDBNet_arch as RRDBNet_arch
 import models.archs.SPSR_arch as spsr
 import models.archs.SRResNet_arch as SRResNet_arch
-import models.archs.StructuredSwitchedGenerator as ssg
 import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.feature_arch as feature_arch
 import models.archs.panet.panet as panet
 import models.archs.rcan as rcan
-from models.archs.ChainedEmbeddingGen import ChainedEmbeddingGen, ChainedEmbeddingGenWithStructure, \
-    StructuredChainedEmbeddingGenWithBypass, MultifacetedChainedEmbeddingGen
+import models.archs.ChainedEmbeddingGen as chained
 
 logger = logging.getLogger('base')
 
@@ -72,76 +70,15 @@ def define_G(opt, net_key='network_G', scale=None):
                          nb=opt_net['nb'], upscale=opt_net['scale'])
     elif which_model == "spsr_switched":
         netG = spsr.SwitchedSpsr(in_nc=3, nf=opt_net['nf'], upscale=opt_net['scale'], init_temperature=opt_net['temperature'])
-    elif which_model == "spsr5":
-        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        netG = spsr.Spsr5(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                          multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 2,
-                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
-    elif which_model == "spsr6":
-        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        netG = spsr.Spsr6(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                          multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
-                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
     elif which_model == "spsr7":
         recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
         xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
         netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
                           multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
                           init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
-    elif which_model == "spsr9":
-        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        netG = spsr.Spsr9(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                          multiplexer_reductions=opt_net['multiplexer_reductions'] if 'multiplexer_reductions' in opt_net.keys() else 3,
-                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
-    elif which_model == "ssgr1":
-        recurrent = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
-        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        netG = ssg.SSGr1(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                         init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10, recurrent=recurrent)
-    elif which_model == 'stacked_switches':
-        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
-        netG = ssg.StackedSwitchGenerator(in_nc=in_nc, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                                          init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
-    elif which_model == 'stacked_switches_5lyr':
-        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
-        netG = ssg.StackedSwitchGenerator5Layer(in_nc=in_nc, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                                                init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
-    elif which_model == 'ssg_deep':
-        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
-                           init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
-    elif which_model == 'ssg_simpler':
-        xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
-        netG = ssg.SsgSimpler(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms,
-                              init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
-    elif which_model == 'ssg_teco':
-        netG = ssg.StackedSwitchGenerator2xTeco(nf=opt_net['nf'], xforms=opt_net['num_transforms'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
-    elif which_model == 'big_switch':
-        netG = SwitchedGen_arch.TheBigSwitch(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
-                                             init_temperature=opt_net['temperature'])
-    elif which_model == 'artist':
-        netG = SwitchedGen_arch.ArtistGen(opt_net['in_nc'], nf=opt_net['nf'], xforms=opt_net['num_transforms'], upscale=opt_net['scale'],
-                                          init_temperature=opt_net['temperature'])
-    elif which_model == 'chained_gen':
-        in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
-        netG = ChainedEmbeddingGen(depth=opt_net['depth'], in_nc=in_nc)
-    elif which_model == 'chained_gen_structured':
-        rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
-        recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
-        recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
-        in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
-        netG = ChainedEmbeddingGenWithStructure(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, in_nc=in_nc)
-    elif which_model == 'chained_gen_structured_with_bypass':
-        rec = opt_net['recurrent'] if 'recurrent' in opt_net.keys() else False
-        recnf = opt_net['recurrent_nf'] if 'recurrent_nf' in opt_net.keys() else 3
-        recstd = opt_net['recurrent_stride'] if 'recurrent_stride' in opt_net.keys() else 2
-        bypass_bias = opt_net['bypass_bias'] if 'bypass_bias' in opt_net.keys() else 0
-        netG = StructuredChainedEmbeddingGenWithBypass(depth=opt_net['depth'], recurrent=rec, recurrent_nf=recnf, recurrent_stride=recstd, bypass_bias=bypass_bias)
     elif which_model == 'multifaceted_chained':
         scale = opt_net['scale'] if 'scale' in opt_net.keys() else 2
-        netG = MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
+        netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
     elif which_model == "flownet2":
         from models.flownet2.models import FlowNet2
         ld = torch.load(opt_net['load_path'])