forked from mrq/DL-Art-School
Report the standard deviation of ref branches
This patch also ups the contribution
This commit is contained in:
parent
668bfbff6d
commit
9e5aa166de
|
@ -361,7 +361,7 @@ class RefJoiner(nn.Module):
|
||||||
super(RefJoiner, self).__init__()
|
super(RefJoiner, self).__init__()
|
||||||
self.lin1 = nn.Linear(512, 256)
|
self.lin1 = nn.Linear(512, 256)
|
||||||
self.lin2 = nn.Linear(256, nf)
|
self.lin2 = nn.Linear(256, nf)
|
||||||
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.05, norm=False)
|
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
|
||||||
|
|
||||||
def forward(self, x, ref):
|
def forward(self, x, ref):
|
||||||
ref = self.lin1(ref)
|
ref = self.lin1(ref)
|
||||||
|
@ -374,11 +374,11 @@ class RefJoiner(nn.Module):
|
||||||
class ModuleWithRef(nn.Module):
|
class ModuleWithRef(nn.Module):
|
||||||
def __init__(self, nf, mcnv, *args):
|
def __init__(self, nf, mcnv, *args):
|
||||||
super(ModuleWithRef, self).__init__()
|
super(ModuleWithRef, self).__init__()
|
||||||
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
|
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2)
|
||||||
self.multi = mcnv(*args)
|
self.multi = mcnv(*args)
|
||||||
|
|
||||||
def forward(self, x, ref):
|
def forward(self, x, ref):
|
||||||
out = self.join(x, ref)
|
out, _ = self.join(x, ref)
|
||||||
return self.multi(out)
|
return self.multi(out)
|
||||||
|
|
||||||
|
|
||||||
|
@ -402,7 +402,7 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
|
|
||||||
# Feature branch
|
# Feature branch
|
||||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||||
self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
|
self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
|
||||||
self.ref_join1 = RefJoiner(nf)
|
self.ref_join1 = RefJoiner(nf)
|
||||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
|
@ -421,9 +421,9 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
# Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
|
# Grad branch. Note - groupnorm on this branch is REALLY bad. Avoid it like the plague.
|
||||||
self.get_g_nopadding = ImageGradientNoPadding()
|
self.get_g_nopadding = ImageGradientNoPadding()
|
||||||
self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
self.grad_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False)
|
||||||
self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
|
self.noise_ref_join_grad = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
|
||||||
self.ref_join3 = RefJoiner(nf)
|
self.ref_join3 = RefJoiner(nf)
|
||||||
self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False, final_norm=False)
|
self.grad_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3, final_norm=False)
|
||||||
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
|
@ -436,8 +436,8 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
|
|
||||||
# Join branch (grad+fea)
|
# Join branch (grad+fea)
|
||||||
self.ref_join4 = RefJoiner(nf)
|
self.ref_join4 = RefJoiner(nf)
|
||||||
self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
|
self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.1)
|
||||||
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
|
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.3)
|
||||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
|
@ -453,42 +453,55 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
self.final_temperature_step = 10000
|
self.final_temperature_step = 10000
|
||||||
|
|
||||||
def forward(self, x, ref, center_coord):
|
def forward(self, x, ref, center_coord):
|
||||||
|
ref_stds = []
|
||||||
|
noise_stds = []
|
||||||
|
|
||||||
x_grad = self.get_g_nopadding(x)
|
x_grad = self.get_g_nopadding(x)
|
||||||
ref = self.reference_processor(ref, center_coord)
|
ref = self.reference_processor(ref, center_coord)
|
||||||
|
|
||||||
x = self.model_fea_conv(x)
|
x = self.model_fea_conv(x)
|
||||||
x1 = x
|
x1 = x
|
||||||
x1 = self.ref_join1(x1, ref)
|
x1, rstd = self.ref_join1(x1, ref)
|
||||||
x1, a1 = self.sw1(x1, True, identity=x)
|
x1, a1 = self.sw1(x1, True, identity=x)
|
||||||
|
ref_stds.append(rstd)
|
||||||
|
|
||||||
x2 = x1
|
x2 = x1
|
||||||
x2 = self.noise_ref_join(x2, torch.randn_like(x2))
|
x2, nstd = self.noise_ref_join(x2, torch.randn_like(x2))
|
||||||
x2 = self.ref_join2(x2, ref)
|
x2, rstd = self.ref_join2(x2, ref)
|
||||||
x2, a2 = self.sw2(x2, True, identity=x1)
|
x2, a2 = self.sw2(x2, True, identity=x1)
|
||||||
|
noise_stds.append(nstd)
|
||||||
|
ref_stds.append(rstd)
|
||||||
|
|
||||||
x_grad = self.grad_conv(x_grad)
|
x_grad = self.grad_conv(x_grad)
|
||||||
x_grad_identity = x_grad
|
x_grad_identity = x_grad
|
||||||
x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
|
x_grad, nstd = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
|
||||||
x_grad = self.ref_join3(x_grad, ref)
|
x_grad, rstd = self.ref_join3(x_grad, ref)
|
||||||
x_grad = self.grad_ref_join(x_grad, x1)
|
x_grad, grad_fea_std = self.grad_ref_join(x_grad, x1)
|
||||||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
|
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
|
||||||
x_grad = self.grad_lr_conv(x_grad)
|
x_grad = self.grad_lr_conv(x_grad)
|
||||||
x_grad = self.grad_lr_conv2(x_grad)
|
x_grad = self.grad_lr_conv2(x_grad)
|
||||||
x_grad_out = self.upsample_grad(x_grad)
|
x_grad_out = self.upsample_grad(x_grad)
|
||||||
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
||||||
|
noise_stds.append(nstd)
|
||||||
|
ref_stds.append(rstd)
|
||||||
|
|
||||||
x_out = x2
|
x_out = x2
|
||||||
x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
|
x_out, nstd = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
|
||||||
x_out = self.ref_join4(x_out, ref)
|
x_out, rstd = self.ref_join4(x_out, ref)
|
||||||
x_out = self.conjoin_ref_join(x_out, x_grad)
|
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
||||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
|
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
|
||||||
x_out = self.final_lr_conv(x_out)
|
x_out = self.final_lr_conv(x_out)
|
||||||
x_out = self.upsample(x_out)
|
x_out = self.upsample(x_out)
|
||||||
x_out = self.final_hr_conv1(x_out)
|
x_out = self.final_hr_conv1(x_out)
|
||||||
x_out = self.final_hr_conv2(x_out)
|
x_out = self.final_hr_conv2(x_out)
|
||||||
|
noise_stds.append(nstd)
|
||||||
|
ref_stds.append(rstd)
|
||||||
|
|
||||||
self.attentions = [a1, a2, a3, a4]
|
self.attentions = [a1, a2, a3, a4]
|
||||||
|
self.noise_stds = torch.stack(noise_stds).mean().detach().cpu()
|
||||||
|
self.ref_stds = torch.stack(ref_stds).mean().detach().cpu()
|
||||||
|
self.grad_fea_std = grad_fea_std.detach().cpu()
|
||||||
|
self.fea_grad_std = fea_grad_std.detach().cpu()
|
||||||
return x_grad_out, x_out, x_grad
|
return x_grad_out, x_out, x_grad
|
||||||
|
|
||||||
def set_temperature(self, temp):
|
def set_temperature(self, temp):
|
||||||
|
@ -509,7 +522,11 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
|
||||||
means = [i[0] for i in mean_hists]
|
means = [i[0] for i in mean_hists]
|
||||||
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
|
hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
|
||||||
val = {"switch_temperature": temp}
|
val = {"switch_temperature": temp,
|
||||||
|
"reference_branch_std_dev": self.ref_stds,
|
||||||
|
"noise_branch_std_dev": self.noise_stds,
|
||||||
|
"grad_branch_feat_intg_std_dev": self.grad_fea_std,
|
||||||
|
"conjoin_branch_grad_intg_std_dev": self.fea_grad_std}
|
||||||
for i in range(len(means)):
|
for i in range(len(means)):
|
||||||
val["switch_%i_specificity" % (i,)] = means[i]
|
val["switch_%i_specificity" % (i,)] = means[i]
|
||||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||||
|
|
|
@ -456,17 +456,17 @@ class ConjoinBlock(nn.Module):
|
||||||
|
|
||||||
# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
|
# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
|
||||||
class ReferenceJoinBlock(nn.Module):
|
class ReferenceJoinBlock(nn.Module):
|
||||||
def __init__(self, nf, residual_weight_init_factor=1, norm=False, block=ConvGnLelu, final_norm=False):
|
def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False):
|
||||||
super(ReferenceJoinBlock, self).__init__()
|
super(ReferenceJoinBlock, self).__init__()
|
||||||
self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
|
self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=3, depth=3,
|
||||||
scale_init=residual_weight_init_factor, norm=norm,
|
scale_init=residual_weight_init_factor, norm=False,
|
||||||
weight_init_factor=residual_weight_init_factor)
|
weight_init_factor=residual_weight_init_factor)
|
||||||
self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
|
self.join_conv = block(nf, nf, norm=final_norm, bias=False, activation=True)
|
||||||
|
|
||||||
def forward(self, x, ref):
|
def forward(self, x, ref):
|
||||||
joined = torch.cat([x, ref], dim=1)
|
joined = torch.cat([x, ref], dim=1)
|
||||||
branch = self.branch(joined)
|
branch = self.branch(joined)
|
||||||
return self.join_conv(x + branch)
|
return self.join_conv(x + branch), torch.std(branch)
|
||||||
|
|
||||||
|
|
||||||
# Basic convolutional upsampling block that uses interpolate.
|
# Basic convolutional upsampling block that uses interpolate.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user