Report the standard deviation of ref branches

This patch also ups the contribution of the ref branches by raising their residual_weight_init_factor values.
James Betker 2020-09-10 16:34:41 -06:00
parent 668bfbff6d
commit 9e5aa166de
2 changed files with 39 additions and 22 deletions
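
Orientation for the diff below: after this change, ReferenceJoinBlock.forward() returns a (features, std) tuple instead of a bare tensor, so every call site has to unpack two values (or discard the stat with _). A minimal sketch of the new calling convention, using a hypothetical stand-in class rather than the repo's implementation:

import torch
import torch.nn as nn

class JoinStub(nn.Module):
    # Hypothetical stand-in with the same return contract as the patched
    # ReferenceJoinBlock: the joined features plus the std-dev of the
    # residual branch, which the caller collects purely for logging.
    def forward(self, x, ref):
        branch = ref - x                       # placeholder residual branch
        return x + branch, torch.std(branch)   # (features, scalar std)

join = JoinStub()
x = torch.randn(1, 64, 8, 8)
ref = torch.randn(1, 64, 8, 8)
out, rstd = join(x, ref)  # callers that ignore the stat: out, _ = join(x, ref)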

File 1 of 2:

@@ -361,7 +361,7 @@ class RefJoiner(nn.Module):
         super(RefJoiner, self).__init__()
         self.lin1 = nn.Linear(512, 256)
         self.lin2 = nn.Linear(256, nf)
-        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.05, norm=False)
+        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)

     def forward(self, x, ref):
         ref = self.lin1(ref)
@@ -374,11 +374,11 @@ class RefJoiner(nn.Module):
 class ModuleWithRef(nn.Module):
     def __init__(self, nf, mcnv, *args):
         super(ModuleWithRef, self).__init__()
-        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2)
         self.multi = mcnv(*args)

     def forward(self, x, ref):
-        out = self.join(x, ref)
+        out, _ = self.join(x, ref)
         return self.multi(out)
@@ -402,7 +402,7 @@ class SwitchedSpsrWithRef2(nn.Module):
         # Feature branch
         self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
-        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
+        self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
         self.ref_join1 = RefJoiner(nf)
         self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                               pre_transform_block=pretransform_fn, transform_block=transform_fn,
@@ -421,9 +421,9 @@ class SwitchedSpsrWithRef2(nn.Module):
         # Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
         self.get_g_nopadding = ImageGradientNoPadding()
         self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
-        self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
+        self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
         self.ref_join3 = RefJoiner(nf)
-        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
+        self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False)
         self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                   pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                   attention_norm=True,
@@ -436,8 +436,8 @@ class SwitchedSpsrWithRef2(nn.Module):
         # Join branch (grad+fea)
         self.ref_join4 = RefJoiner(nf)
-        self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
-        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
+        self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
+        self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3)
         self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                      pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                      attention_norm=True,
@@ -453,42 +453,55 @@ class SwitchedSpsrWithRef2(nn.Module):
         self.final_temperature_step = 10000

     def forward(self, x, ref, center_coord):
+        ref_stds = []
+        noise_stds = []
         x_grad = self.get_g_nopadding(x)
         ref = self.reference_processor(ref, center_coord)

         x = self.model_fea_conv(x)
         x1 = x
-        x1 = self.ref_join1(x1, ref)
+        x1, rstd = self.ref_join1(x1, ref)
         x1, a1 = self.sw1(x1, True, identity=x)
+        ref_stds.append(rstd)

         x2 = x1
-        x2 = self.noise_ref_join(x2, torch.randn_like(x2))
-        x2 = self.ref_join2(x2, ref)
+        x2, nstd = self.noise_ref_join(x2, torch.randn_like(x2))
+        x2, rstd = self.ref_join2(x2, ref)
         x2, a2 = self.sw2(x2, True, identity=x1)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)

         x_grad = self.grad_conv(x_grad)
         x_grad_identity = x_grad
-        x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
-        x_grad = self.ref_join3(x_grad, ref)
-        x_grad = self.grad_ref_join(x_grad, x1)
+        x_grad, nstd = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
+        x_grad, rstd = self.ref_join3(x_grad, ref)
+        x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1)
         x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
         x_grad_out = self.upsample_grad(x_grad)
         x_grad_out = self.grad_branch_output_conv(x_grad_out)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)

         x_out = x2
-        x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
-        x_out = self.ref_join4(x_out, ref)
-        x_out = self.conjoin_ref_join(x_out, x_grad)
+        x_out, nstd = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
+        x_out, rstd = self.ref_join4(x_out, ref)
+        x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
         x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
         x_out = self.final_lr_conv(x_out)
         x_out = self.upsample(x_out)
         x_out = self.final_hr_conv1(x_out)
         x_out = self.final_hr_conv2(x_out)
+        noise_stds.append(nstd)
+        ref_stds.append(rstd)

         self.attentions = [a1, a2, a3, a4]
+        self.noise_stds = torch.stack(noise_stds).mean().detach().cpu()
+        self.ref_stds = torch.stack(ref_stds).mean().detach().cpu()
+        self.grad_fea_std = grad_fea_std.detach().cpu()
+        self.fea_grad_std = fea_grad_std.detach().cpu()
         return x_grad_out, x_out, x_grad

     def set_temperature(self, temp):
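
Worth noting in the hunk above: each join contributes one scalar std, the lists are stacked and averaged once per forward pass, and the results are detached and moved to CPU so storing them on the module does not retain the autograd graph. A standalone sketch of that reduction, with illustrative values:

import torch

ref_stds = [torch.tensor(0.12), torch.tensor(0.09), torch.tensor(0.15), torch.tensor(0.11)]
noise_stds = [torch.tensor(0.03), torch.tensor(0.02), torch.tensor(0.04)]

# stack -> mean -> detach -> cpu, as in the diff: one scalar per branch
# type, safe to keep around for logging without holding the graph.
mean_ref_std = torch.stack(ref_stds).mean().detach().cpu()
mean_noise_std = torch.stack(noise_stds).mean().detach().cpu()
print(mean_ref_std.item(), mean_noise_std.item())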
@@ -509,7 +522,11 @@ class SwitchedSpsrWithRef2(nn.Module):
         mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
         means = [i[0] for i in mean_hists]
         hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
-        val = {"switch_temperature": temp}
+        val = {"switch_temperature": temp,
+               "reference_branch_std_dev": self.ref_stds,
+               "noise_branch_std_dev": self.noise_stds,
+               "grad_branch_feat_intg_std_dev": self.grad_fea_std,
+               "conjoin_branch_grad_intg_std_dev": self.fea_grad_std}
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]

File 2 of 2:

@@ -456,17 +456,17 @@ class ConjoinBlock(nn.Module):

 # Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
 class ReferenceJoinBlock(nn.Module):
-    def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=False):
+    def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False):
         super(ReferenceJoinBlock, self).__init__()
         self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
-                                     scale_init=residual_weight_init_factor, norm=norm,
+                                     scale_init=residual_weight_init_factor, norm=False,
                                      weight_init_factor=residual_weight_init_factor)
         self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)

     def forward(self, x, ref):
         joined = torch.cat([x, ref], dim=1)
         branch = self.branch(joined)
-        return self.join_conv(x + branch)
+        return self.join_conv(x + branch), torch.std(branch)

 # Basic convolutional upsampling block that uses interpolate.
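
Taken together, the block is a residual join: concatenate trunk and reference, run a small conv branch whose weights are initialized near zero (residual_weight_init_factor), add the result back onto the trunk, and now also report torch.std(branch) as a cheap proxy for how much the reference path actually contributes. A self-contained approximation, with a plain Conv2d standing in for the MultiConvBlock/ConvGnLelu internals that this diff does not show:

import torch
import torch.nn as nn

class ReferenceJoinSketch(nn.Module):
    # Simplified stand-in for ReferenceJoinBlock; the real block's
    # MultiConvBlock/ConvGnLelu internals are elided here.
    def __init__(self, nf, residual_weight_init_factor=1.0):
        super().__init__()
        self.branch = nn.Conv2d(nf * 2, nf, kernel_size=3, padding=1)
        # Shrink the initial weights so the residual starts near zero,
        # mimicking what residual_weight_init_factor does in the real block.
        with torch.no_grad():
            self.branch.weight.mul_(residual_weight_init_factor)
            self.branch.bias.zero_()
        self.join_conv = nn.Conv2d(nf, nf, kernel_size=3, padding=1, bias=False)

    def forward(self, x, ref):
        branch = self.branch(torch.cat([x, ref], dim=1))
        # std of the residual ~ how strongly the reference is mixed in
        return self.join_conv(x + branch), torch.std(branch)

blk = ReferenceJoinSketch(nf=64, residual_weight_init_factor=0.1)
out, std = blk(torch.randn(1, 64, 16, 16), torch.randn(1, 64, 16, 16))
print(out.shape, std.item())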