forked from mrq/DL-Art-School
Add DualOutputSRG
Also removes the old multi-return mechanism that Generators supported, and fixes AttentionNorm.
commit 1b1431133b (parent a2285ff2ee)
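For orientation, here is a minimal sketch (not part of this commit) of the two-output generator contract that SRGANModel now assumes: the generator returns a feature-branch image used for the pixel and perceptual losses plus the final image fed to the discriminator. The toy module and names below are illustrative only.

```python
# Illustrative only: a toy stand-in for DualOutputSRG showing the
# (feature_output, final_output) return convention SRGANModel now unpacks.
import torch
import torch.nn as nn

class ToyDualOutputGen(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Conv2d(3, 3, 3, padding=1)   # early branch
        self.head = nn.Conv2d(3, 3, 3, padding=1)   # final branch

    def forward(self, x):
        fea = self.body(x)      # drives cri_pix / cri_fea in the trainer
        out = self.head(fea)    # goes to the discriminator and saved samples
        return fea, out

netG = ToyDualOutputGen()
fea_GenOut, fake_GenOut = netG(torch.randn(1, 3, 16, 16))
```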
@@ -70,9 +70,11 @@ class SRGANModel(BaseModel):
            else:
                raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = train_opt['feature_weight']
-            self.l_fea_w_decay = train_opt['feature_weight_decay']
+            self.l_fea_w_decay_start = train_opt['feature_weight_decay_start']
            self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
            self.l_fea_w_minimum = train_opt['feature_weight_minimum']
+            if self.l_fea_w_decay_start:
+                self.l_fea_w_decay_step_size = (self.l_fea_w - self.l_fea_w_minimum) / (self.l_fea_w_decay_steps)
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
@@ -202,16 +204,17 @@ class SRGANModel(BaseModel):
        for p in self.netD.parameters():
            p.requires_grad = False

-        if step > self.D_init_iters:
+        if step >= self.D_init_iters:
            self.optimizer_G.zero_grad()

        self.swapout_D(step)
        self.swapout_G(step)

        # Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
-        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+        if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
            for p in self.netG.parameters():
-                p.requires_grad = True
+                if p.dtype != torch.int64 and p.dtype != torch.bool:
+                    p.requires_grad = True
        else:
            for p in self.netG.parameters():
                p.requires_grad = False
@@ -227,35 +230,28 @@ class SRGANModel(BaseModel):
            _t = time()

        self.fake_GenOut = []
+        self.fea_GenOut = []
        self.fake_H = []
        var_ref_skips = []
        for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
-            fake_GenOut = self.netG(var_L)
+            fea_GenOut, fake_GenOut = self.netG(var_L)

            if _profile:
                print("Gen forward %f" % (time() - _t,))
                _t = time()

-            # Extract the image output. For generators that output skip-through connections, the master output is always
-            # the first element of the tuple.
-            if isinstance(fake_GenOut, tuple):
-                gen_img = fake_GenOut[0]
-                # The following line detaches all generator outputs that are not None.
-                self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)]))
-                var_ref = (var_ref,)  # This is a tuple for legacy reasons.
-            else:
-                gen_img = fake_GenOut
-                self.fake_GenOut.append(fake_GenOut.detach())
+            self.fake_GenOut.append(fake_GenOut.detach())
+            self.fea_GenOut.append(fea_GenOut.detach())

            l_g_total = 0
-            if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
                if self.cri_pix:  # pixel loss
-                    l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
+                    l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
                    l_g_pix_log = l_g_pix / self.l_pix_w
                    l_g_total += l_g_pix
                if self.cri_fea:  # feature loss
                    real_fea = self.netF(pix).detach()
-                    fake_fea = self.netF(gen_img)
+                    fake_fea = self.netF(fea_GenOut)
                    l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                    l_g_fea_log = l_g_fea / self.l_fea_w
                    l_g_total += l_g_fea
@@ -266,8 +262,13 @@ class SRGANModel(BaseModel):

                    # Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
                    # in the resultant image.
-                    if step % self.l_fea_w_decay_steps == 0:
-                        self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
+                    if self.l_fea_w_decay_start and step > self.l_fea_w_decay_start:
+                        self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w - self.l_fea_w_decay_step_size * (step - self.l_fea_w_decay_start))

+            # Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
+            # Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
+            # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
+            # it should target this value.
+
            if self.l_gan_w > 0:
                if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
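A side note on the hunk above: the schedule changes from a periodic multiplicative decay to a linear ramp that starts at feature_weight_decay_start and bottoms out at feature_weight_minimum. A standalone sketch of that schedule (assumed helper, not in the commit):

```python
def feature_weight_at(step, w_initial, w_minimum, decay_start, decay_steps):
    # Linear ramp from w_initial down to w_minimum over decay_steps steps,
    # beginning once `step` passes decay_start (mirrors the new l_fea_w update).
    if not decay_start or step <= decay_start:
        return w_initial
    step_size = (w_initial - w_minimum) / decay_steps
    return max(w_minimum, w_initial - step_size * (step - decay_start))

# e.g. feature_weight_at(30000, 1.0, 0.1, 20000, 40000) -> 0.775
```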
@@ -304,7 +305,7 @@ class SRGANModel(BaseModel):
            for p in self.netD.parameters():
                p.requires_grad = True

-            noise = torch.randn_like(var_ref[0]) * noise_theta
+            noise = torch.randn_like(var_ref) * noise_theta
            noise.to(self.device)
            self.optimizer_D.zero_grad()
            real_disc_images = []
@@ -312,17 +313,17 @@ class SRGANModel(BaseModel):
            for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
                # Re-compute generator outputs (post-update).
                with torch.no_grad():
-                    fake_H = self.netG(var_L)
+                    _, fake_H = self.netG(var_L)
                    # The following line detaches all generator outputs that are not None.
-                    fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
+                    fake_H = fake_H.detach()

                if _profile:
                    print("Gen forward for disc %f" % (time() - _t,))
                    _t = time()

                # Apply noise to the inputs to slow discriminator convergence.
-                var_ref = (var_ref + noise,)
-                fake_H = (fake_H[0] + noise,) + fake_H[1:]
+                var_ref = var_ref + noise
+                fake_H = fake_H + noise
                if self.opt['train']['gan_type'] == 'gan':
                    # need to forward and backward separately, since batch norm statistics differ
                    # real
@@ -340,10 +341,10 @@ class SRGANModel(BaseModel):
                if self.opt['train']['gan_type'] == 'pixgan':
                    # randomly determine portions of the image to swap to keep the discriminator honest.
                    pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
-                    disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
-                    b, _, w, h = var_ref[0].shape
-                    real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)
-                    fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device)
+                    disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
+                    b, _, w, h = var_ref.shape
+                    real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
+                    fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
                    SWAP_MAX_DIM = w // 4
                    SWAP_MIN_DIM = 16
                    assert SWAP_MAX_DIM > 0
@@ -360,9 +361,9 @@ class SRGANModel(BaseModel):
                            swap_w = w - swap_x
                        if swap_y + swap_h > h:
                            swap_h = h - swap_y
-                        t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
-                        fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
-                        var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
+                        t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
+                        fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
+                        var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
                        real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
                        fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0

@@ -422,8 +423,8 @@ class SRGANModel(BaseModel):
                    _t = time()

                # Append var_ref here, so that we can inspect the alterations the disc made if pixgan
-                var_ref_skips.append(var_ref[0].detach())
-                self.fake_H.append(fake_H[0].detach())
+                var_ref_skips.append(var_ref.detach())
+                self.fake_H.append(fake_H.detach())
            self.optimizer_D.step()


@@ -436,32 +437,28 @@ class SRGANModel(BaseModel):
            sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
            os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
+            os.makedirs(os.path.join(sample_save_path, "gen_fea"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
            os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
-            multi_gen = False
-            if isinstance(self.fake_GenOut[0], tuple):
-                os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
-                multi_gen = True
+            os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)

            # fed_LQ is not chunked.
            for i in range(self.mega_batch_factor):
                utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
                utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
                utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
-                if multi_gen:
-                    utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
+                utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i)))
                if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
                    utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
                    utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
                    utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
                    utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
-                else:
-                    utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))

        # Log metrics
-        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+        if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
            if self.cri_pix:
                self.add_log_entry('l_g_pix', l_g_pix_log.item())
            if self.cri_fea:
@@ -196,7 +196,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
        if self.upsample_factor > 2:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.upconv2(x)
-        return self.final_conv(self.hr_conv(x)),
+        x = self.final_conv(self.hr_conv(x))
+        return x, x

    def set_temperature(self, temp):
        [sw.set_temperature(temp) for sw in self.switches]
@@ -319,3 +320,105 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
            val["switch_%i_specificity" % (i,)] = means[i]
            val["switch_%i_histogram" % (i,)] = hists[i]
        return val
+
+
+class DualOutputSRG(nn.Module):
+    def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
+                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
+                 heightened_final_step=50000, upsample_factor=1,
+                 add_scalable_noise_to_transforms=False):
+        super(DualOutputSRG, self).__init__()
+        switches = []
+        self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)
+
+        self.fea_upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.fea_upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.fea_hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.fea_final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
+
+        self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
+        self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)
+
+        for _ in range(switch_depth):
+            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
+            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
+            transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
+            switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
+                                                       pre_transform_block=pretransform_fn, transform_block=transform_fn,
+                                                       transform_count=trans_counts, init_temp=initial_temp,
+                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))

+        self.switches = nn.ModuleList(switches)
+        self.transformation_counts = trans_counts
+        self.init_temperature = initial_temp
+        self.final_temperature_step = final_temperature_step
+        self.heightened_temp_min = heightened_temp_min
+        self.heightened_final_step = heightened_final_step
+        self.attentions = None
+        self.upsample_factor = upsample_factor
+        assert self.upsample_factor == 2 or self.upsample_factor == 4
+
+    def forward(self, x):
+        x = self.initial_conv(x)
+
+        self.attentions = []
+        for i, sw in enumerate(self.switches):
+            x, att = sw.forward(x, True)
+            self.attentions.append(att)
+
+            if i == len(self.switches)-2:
+                fea = self.fea_upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
+                if self.upsample_factor > 2:
+                    fea = F.interpolate(fea, scale_factor=2, mode="nearest")
+                fea = self.fea_upconv2(fea)
+                fea = self.fea_final_conv(self.hr_conv(fea))
+
+        x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
+        if self.upsample_factor > 2:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        x = self.upconv2(x)
+        return fea, self.final_conv(self.hr_conv(x))
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
+            if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
+                # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
+                # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
+                h_steps_total = self.heightened_final_step - self.final_temperature_step
+                h_steps_current = max(min(step - self.final_temperature_step, h_steps_total), 1)
+                # The "gap" will represent the steps that need to be traveled as a linear function.
+                h_gap = 1 / self.heightened_temp_min
+                temp = h_gap * h_steps_current / h_steps_total
+                # Invert temperature to represent reality on this side of the curve
+                temp = 1 / temp
+            self.set_temperature(temp)
+            if step % 50 == 0:
+                [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,)) for i in range(len(self.switches))]
+
+    def get_debug_values(self, step):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"switch_temperature": temp}
+        for i in range(len(means)):
+            val["switch_%i_specificity" % (i,)] = means[i]
+            val["switch_%i_histogram" % (i,)] = hists[i]
+        return val
+
+
+    def load_state_dict(self, state_dict, strict=True):
+        # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
+        t_state = self.state_dict()
+        if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict.keys():
+            for i in range(4):
+                state_dict['switches.%i.switch.attention_norm.accumulator' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator' % (i,)]
+                state_dict['switches.%i.switch.attention_norm.accumulator_index' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_index' % (i,)]
+                state_dict['switches.%i.switch.attention_norm.accumulator_filled' % (i,)] = t_state['switches.%i.switch.attention_norm.accumulator_filled' % (i,)]
+        super(DualOutputSRG, self).load_state_dict(state_dict, strict)
@@ -51,7 +51,6 @@ class Discriminator_VGG_128(nn.Module):
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
-        x = x[0]
        fea = self.lrelu(self.conv0_0(x))
        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))

@@ -127,7 +126,6 @@ class Discriminator_VGG_PixLoss(nn.Module):
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, flatten=True):
-        x = x[0]
        fea0 = self.lrelu(self.conv0_0(x))
        fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0)))

@@ -205,7 +203,6 @@ class Discriminator_UNet(nn.Module):
        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)

    def forward(self, x, flatten=True):
-        x = x[0]
        fea0 = self.conv0_0(x)
        fea0 = self.conv0_1(fea0)

@@ -78,6 +78,15 @@ def define_G(opt, net_key='network_G'):
                                              initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
                                              heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
                                              upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
+    elif which_model == "DualOutputSRG":
+        netG = SwitchedGen_arch.DualOutputSRG(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
+                                              switch_reductions=opt_net['switch_reductions'],
+                                              switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
+                                              trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
+                                              transformation_filters=opt_net['transformation_filters'],
+                                              initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
+                                              heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
+                                              upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])

    # image corruption
    elif which_model == 'HighToLowResNet':
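For reference, the option keys the new define_G branch reads for DualOutputSRG, expressed as a Python dict. This is a sketch only: the values (and the 'which_model_G' key name) are assumptions, not taken from the commit.

```python
# Hypothetical opt_net block consumed by the DualOutputSRG branch above.
opt_net = {
    'which_model_G': 'DualOutputSRG',   # key name assumed
    'switch_depth': 4,
    'switch_filters': 64,
    'switch_reductions': 4,
    'switch_processing_layers': 2,      # illustrative value
    'trans_counts': 8,                  # illustrative value
    'trans_kernel_sizes': 3,            # illustrative value
    'trans_layers': 3,                  # illustrative value
    'transformation_filters': 64,
    'temperature': 10,                  # illustrative value
    'temperature_final_step': 50000,
    'heightened_temp_min': 1,
    'heightened_final_step': 50000,
    'add_noise': False,
}
```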
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
    #### options
    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_dual_srg.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
@@ -93,7 +93,7 @@ if __name__ == "__main__":
                   torch.randn(1, 3, 64, 64),
                   device='cuda')
    '''
-    test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
+    test_stability(functools.partial(srg.DualOutputSRG,
                                     switch_depth=4,
                                     switch_filters=64,
                                     switch_reductions=4,