Report the standard deviation of ref branches

This patch also ups the contribution of those branches by raising their residual_weight_init_factor values.
parent 668bfbff6d
commit 9e5aa166de
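For context: the reported value is just torch.std over each residual join's branch output, a cheap scalar proxy for how much a join is actually contributing. A standalone sketch of the pattern (TinyRefJoin below is a hypothetical stand-in, not this repo's ReferenceJoinBlock, which appears in the final hunk):

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for ReferenceJoinBlock, reduced to the essentials:
    # a residual branch whose output std is returned alongside the joined features.
    class TinyRefJoin(nn.Module):
        def __init__(self, nf):
            super().__init__()
            self.branch = nn.Conv2d(nf * 2, nf, kernel_size=3, padding=1)

        def forward(self, x, ref):
            branch = self.branch(torch.cat([x, ref], dim=1))
            # std of the residual is a cheap scalar health metric for the join.
            return x + branch, torch.std(branch)

    out, std = TinyRefJoin(8)(torch.randn(1, 8, 16, 16), torch.randn(1, 8, 16, 16))
    print(out.shape, std.item())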
@@ -361,7 +361,7 @@ class RefJoiner(nn.Module):
         super(RefJoiner, self).__init__()
         self.lin1 = nn.Linear(512, 256)
         self.lin2 = nn.Linear(256, nf)
-        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.05, norm=False)
+        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
 
     def forward(self, x, ref):
         ref = self.lin1(ref)
@@ -374,11 +374,11 @@ class RefJoiner(nn.Module):
 class ModuleWithRef(nn.Module):
     def __init__(self, nf, mcnv, *args):
         super(ModuleWithRef, self).__init__()
-        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2)
         self.multi = mcnv(*args)
 
     def forward(self, x, ref):
-        out = self.join(x, ref)
+        out, _ = self.join(x, ref)
         return self.multi(out)
 
 
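The `out, _ = self.join(x, ref)` change above is the knock-on effect of ReferenceJoinBlock now returning a (features, std) pair (see the final hunk): call sites that don't track the statistic simply discard it. An illustrative sketch of the convention:

    import torch

    # Illustrative only: returning an auxiliary statistic as a tuple forces every
    # caller to acknowledge it; the underscore discards it where it isn't tracked.
    def join(x, ref):
        out = x + ref
        return out, torch.std(out)

    out, _ = join(torch.randn(4), torch.randn(4))    # statistic ignored
    out, std = join(torch.randn(4), torch.randn(4))  # statistic collected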
@@ -402,7 +402,7 @@ class SwitchedSpsrWithRef2(nn.Module):
 
         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
-        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
+        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
         self.ref_join1 = RefJoiner(nf)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
@@ -421,9 +421,9 @@ class SwitchedSpsrWithRef2(nn.Module):
         # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
-        self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
+        self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
         self.ref_join3 = RefJoiner(nf)
-        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
+        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False)
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                   pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
@@ -436,8 +436,8 @@ class SwitchedSpsrWithRef2(nn.Module):
 
         # Join branch (grad+fea)
         self.ref_join4 = RefJoiner(nf)
-        self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
-        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
+        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3)
         self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                      pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
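The residual_weight_init_factor bumps in these hunks (.01/.05 to .1, .2 to .3) are what the commit message means by "ups the contribution". A minimal sketch of the mechanism, assuming (from the parameter names; the repo's MultiConvBlock is not shown here) that the factor scales the residual branch's output at initialization:

    import torch
    import torch.nn as nn

    # Hypothetical illustration only -- the real scaling lives inside MultiConvBlock.
    class ScaledResidualBranch(nn.Module):
        def __init__(self, nf, scale_init=.1):
            super().__init__()
            self.conv = nn.Conv2d(nf * 2, nf, kernel_size=3, padding=1)
            # Small learnable scale: the branch starts as a gentle perturbation of the trunk.
            self.scale = nn.Parameter(torch.full((1,), float(scale_init)))

        def forward(self, x, ref):
            return x + self.conv(torch.cat([x, ref], dim=1)) * self.scale

    out = ScaledResidualBranch(8, scale_init=.1)(torch.randn(1, 8, 16, 16),
                                                 torch.randn(1, 8, 16, 16))

Raising the factor from .01 to .1 lets the reference/noise branches start an order of magnitude stronger, which is presumably why their std is now worth monitoring.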
@@ -453,42 +453,55 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000
 
     def forward(self, x, ref, center_coord):
+        ref_stds = []
+        noise_stds = []
+
         x_grad = self.get_g_nopadding(x)
         ref = self.reference_processor(ref, center_coord)
 
         x = self.model_fea_conv(x)
         x1 = x
-        x1 = self.ref_join1(x1, ref)
+        x1, rstd = self.ref_join1(x1, ref)
         x1, a1 = self.sw1(x1, True, identity=x)
+        ref_stds.append(rstd)
 
         x2 = x1
-        x2 = self.noise_ref_join(x2, torch.randn_like(x2))
-        x2 = self.ref_join2(x2, ref)
+        x2, nstd = self.noise_ref_join(x2, torch.randn_like(x2))
+        x2, rstd = self.ref_join2(x2, ref)
         x2, a2 = self.sw2(x2, True, identity=x1)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)
 
         x_grad = self.grad_conv(x_grad)
         x_grad_identity = x_grad
-        x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
-        x_grad = self.ref_join3(x_grad, ref)
-        x_grad = self.grad_ref_join(x_grad, x1)
+        x_grad, nstd = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
+        x_grad, rstd = self.ref_join3(x_grad, ref)
+        x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1)
         x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)
 
         x_out = x2
-        x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
-        x_out = self.ref_join4(x_out, ref)
-        x_out = self.conjoin_ref_join(x_out, x_grad)
+        x_out, nstd = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
+        x_out, rstd = self.ref_join4(x_out, ref)
+        x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
         x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
         x_out = self.final_hr_conv2(x_out)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)
 
         self.attentions = [a1, a2, a3, a4]
 
+        self.noise_stds = torch.stack(noise_stds).mean().detach().cpu()
+        self.ref_stds = torch.stack(ref_stds).mean().detach().cpu()
+        self.grad_fea_std = grad_fea_std.detach().cpu()
+        self.fea_grad_std = fea_grad_std.detach().cpu()
         return x_grad_out, x_out, x_grad
 
     def set_temperature(self, temp):
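The aggregation at the end of forward() relies on torch.std returning a 0-dim tensor, so the per-join scalars stack cleanly before averaging:

    import torch

    # torch.std over a whole tensor yields a 0-dim tensor, so a list of them
    # can be stacked and reduced to one loggable scalar.
    stds = [torch.std(torch.randn(4, 4)) for _ in range(3)]
    mean_std = torch.stack(stds).mean().detach().cpu()
    print(mean_std)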
@@ -509,7 +522,11 @@ class SwitchedSpsrWithRef2(nn.Module):
         mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
         means = [i[0] for i in mean_hists]
         hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
-        val = {"switch_temperature": temp}
+        val = {"switch_temperature": temp,
+               "reference_branch_std_dev": self.ref_stds,
+               "noise_branch_std_dev": self.noise_stds,
+               "grad_branch_feat_intg_std_dev": self.grad_fea_std,
+               "conjoin_branch_grad_intg_std_dev": self.fea_grad_std}
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
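These keys then flow out with whatever metrics hook builds val. A hypothetical consumer (the method name get_debug_values is an assumption; it is not shown in this hunk):

    # Hypothetical: assumes the snippet above is exposed as model.get_debug_values(temp).
    debug = model.get_debug_values(temp=1.0)
    for key, value in debug.items():
        if key.endswith("_std_dev"):
            print("%s: %.4f" % (key, float(value)))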
@@ -456,17 +456,17 @@ class ConjoinBlock(nn.Module):
 
 # Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
 class ReferenceJoinBlock(nn.Module):
-    def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=False):
+    def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False):
         super(ReferenceJoinBlock, self).__init__()
         self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
-                                     scale_init=residual_weight_init_factor, norm=norm,
+                                     scale_init=residual_weight_init_factor, norm=False,
                                      weight_init_factor=residual_weight_init_factor)
         self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
 
     def forward(self, x, ref):
         joined = torch.cat([x, ref], dim=1)
         branch = self.branch(joined)
-        return self.join_conv(x + branch)
+        return self.join_conv(x + branch), torch.std(branch)
 
 
 # Basic convolutional upsampling block that uses interpolate.
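One detail on the new return value: torch.std called with just a tensor reduces over all elements and applies Bessel's correction by default:

    import torch

    t = torch.tensor([[1., 2.], [3., 4.]])
    print(torch.std(t))                  # tensor(1.2910), unbiased (divides by n-1)
    print(torch.std(t, unbiased=False))  # tensor(1.1180), population std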