From 9e5aa166de2ccb2b30ad2e57878031b5cb42a38b Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 10 Sep 2020 16:34:41 -0600
Subject: [PATCH] Report the standard deviation of ref branches

This patch also ups the contribution of the reference branches by raising
their residual weight init factors.
---
 codes/models/archs/SPSR_arch.py | 55 +++++++++++++++++++++------------
 codes/models/archs/arch_util.py |  6 ++--
 2 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index d9c59d1f..e22fcc64 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -361,7 +361,7 @@ class RefJoiner(nn.Module):
         super(RefJoiner, self).__init__()
         self.lin1 = nn.Linear(512, 256)
         self.lin2 = nn.Linear(256, nf)
-        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.05, norm=False)
+        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
 
     def forward(self, x, ref):
         ref = self.lin1(ref)
@@ -374,11 +374,11 @@ class RefJoiner(nn.Module):
 class ModuleWithRef(nn.Module):
     def __init__(self, nf, mcnv, *args):
         super(ModuleWithRef, self).__init__()
-        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2)
         self.multi = mcnv(*args)
 
     def forward(self, x, ref):
-        out = self.join(x, ref)
+        out, _ = self.join(x, ref)
         return self.multi(out)
 
 
@@ -402,7 +402,7 @@ class SwitchedSpsrWithRef2(nn.Module):
 
         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
-        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
+        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
         self.ref_join1 = RefJoiner(nf)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
@@ -421,9 +421,9 @@ class SwitchedSpsrWithRef2(nn.Module):
         # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
-        self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
+        self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
         self.ref_join3 = RefJoiner(nf)
-        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
+        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False)
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                   pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
@@ -436,8 +436,8 @@ class SwitchedSpsrWithRef2(nn.Module):
 
         # Join branch (grad+fea)
         self.ref_join4 = RefJoiner(nf)
-        self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
-        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
+        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3)
         self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                      pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
@@ -453,42 +453,55 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000
 
     def forward(self, x, ref, center_coord):
+        ref_stds = []
+        noise_stds = []
+
         x_grad = self.get_g_nopadding(x)
         ref = self.reference_processor(ref, center_coord)
 
         x = self.model_fea_conv(x)
         x1 = x
-        x1 = self.ref_join1(x1, ref)
+        x1, rstd = self.ref_join1(x1, ref)
         x1, a1 = self.sw1(x1, True, identity=x)
+        ref_stds.append(rstd)
 
         x2 = x1
-        x2 = self.noise_ref_join(x2, torch.randn_like(x2))
-        x2 = self.ref_join2(x2, ref)
+        x2, nstd = self.noise_ref_join(x2, torch.randn_like(x2))
+        x2, rstd = self.ref_join2(x2, ref)
         x2, a2 = self.sw2(x2, True, identity=x1)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)
 
         x_grad = self.grad_conv(x_grad)
         x_grad_identity = x_grad
-        x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
-        x_grad = self.ref_join3(x_grad, ref)
-        x_grad = self.grad_ref_join(x_grad, x1)
+        x_grad, nstd = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
+        x_grad, rstd = self.ref_join3(x_grad, ref)
+        x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1)
         x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)
 
         x_out = x2
-        x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
-        x_out = self.ref_join4(x_out, ref)
-        x_out = self.conjoin_ref_join(x_out, x_grad)
+        x_out, nstd = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
+        x_out, rstd = self.ref_join4(x_out, ref)
+        x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
         x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
         x_out = self.final_hr_conv2(x_out)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)
 
         self.attentions = [a1, a2, a3, a4]
-
+        self.noise_stds = torch.stack(noise_stds).mean().detach().cpu()
+        self.ref_stds = torch.stack(ref_stds).mean().detach().cpu()
+        self.grad_fea_std = grad_fea_std.detach().cpu()
+        self.fea_grad_std = fea_grad_std.detach().cpu()
         return x_grad_out, x_out, x_grad
 
     def set_temperature(self, temp):
@@ -509,7 +522,11 @@ class SwitchedSpsrWithRef2(nn.Module):
         mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
         means = [i[0] for i in mean_hists]
         hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
-        val = {"switch_temperature": temp}
+        val = {"switch_temperature": temp,
+               "reference_branch_std_dev": self.ref_stds,
+               "noise_branch_std_dev": self.noise_stds,
+               "grad_branch_feat_intg_std_dev": self.grad_fea_std,
+               "conjoin_branch_grad_intg_std_dev": self.fea_grad_std}
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index 9e808e40..22bde274 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -456,17 +456,17 @@ class ConjoinBlock(nn.Module):
 
 # Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
 class ReferenceJoinBlock(nn.Module):
-    def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=False):
+    def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False):
         super(ReferenceJoinBlock, self).__init__()
         self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
-                                     scale_init=residual_weight_init_factor, norm=norm,
+                                     scale_init=residual_weight_init_factor, norm=False,
                                      weight_init_factor=residual_weight_init_factor)
         self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
 
     def forward(self, x, ref):
         joined = torch.cat([x, ref], dim=1)
         branch = self.branch(joined)
-        return self.join_conv(x + branch)
+        return self.join_conv(x + branch), torch.std(branch)
 
 
 # Basic convolutional upsampling block that uses interpolate.
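
Reviewer note: below is a minimal, self-contained sketch of the join-and-measure pattern this patch settles on. `TinyRefJoin` is a hypothetical stand-in, with plain `Conv2d` layers replacing the real `MultiConvBlock`/`ConvGnLelu`, so it illustrates the `(output, branch_std)` contract rather than the actual arch_util implementation.

```python
import torch
import torch.nn as nn

# Illustrative stand-in for ReferenceJoinBlock: a residual branch mixes the
# reference into the trunk, and torch.std(branch) reports how strongly that
# branch is contributing. Plain Conv2d replaces MultiConvBlock/ConvGnLelu.
class TinyRefJoin(nn.Module):
    def __init__(self, nf, residual_weight_init_factor=.1):
        super(TinyRefJoin, self).__init__()
        # Residual branch sees the trunk and the reference concatenated.
        self.branch = nn.Conv2d(nf * 2, nf, kernel_size=3, padding=1)
        # Down-scale the branch weights at init so the reference contribution
        # starts near zero, mirroring residual_weight_init_factor above.
        with torch.no_grad():
            self.branch.weight.mul_(residual_weight_init_factor)
            self.branch.bias.zero_()
        self.join_conv = nn.Conv2d(nf, nf, kernel_size=3, padding=1, bias=False)

    def forward(self, x, ref):
        branch = self.branch(torch.cat([x, ref], dim=1))
        # The scalar std is what forward() collects into ref_stds/noise_stds.
        return self.join_conv(x + branch), torch.std(branch)

x, ref = torch.randn(1, 64, 24, 24), torch.randn(1, 64, 24, 24)
out, branch_std = TinyRefJoin(64)(x, ref)
print(out.shape, branch_std.item())  # std is small at init by construction
```

The reported std starts near zero because the branch weights are scaled down at init; watching it rise over training is how the new debug values show whether the reference path is actually being used.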
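
And a similarly hedged sketch of the bookkeeping side: forward() accumulates one scalar std per join, and get_debug_values() reports the averaged, detached results. `summarize_stds` is a hypothetical helper condensing the `torch.stack(...).mean().detach().cpu()` chain from the patch.

```python
import torch

# Hypothetical helper mirroring the aggregation in forward(): each
# ReferenceJoinBlock call yields a 0-dim std tensor; stack them, average,
# and detach to a CPU scalar so logging never holds onto the autograd graph.
def summarize_stds(stds):
    return torch.stack(stds).mean().detach().cpu()

# Example: four ref joins produced these stds during one forward pass.
ref_stds = [torch.tensor(0.08), torch.tensor(0.12),
            torch.tensor(0.10), torch.tensor(0.09)]
val = {"switch_temperature": 1.0,
       "reference_branch_std_dev": summarize_stds(ref_stds)}
print(val)  # reference_branch_std_dev = 0.0975
```

Detaching before storing matters here: forward() stashes these values on self, and keeping graph-attached tensors alive across iterations would leak memory.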