Add DualOutputSRG

Also removes the old multi-return mechanism that Generators support.
Also fixes AttentionNorm.
This commit is contained in:
James Betker 2020-07-14 09:28:24 -06:00
parent a2285ff2ee
commit 1b1431133b
6 changed files with 161 additions and 55 deletions

View File

@ -70,9 +70,11 @@ class SRGANModel(BaseModel):
else:
raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
self.l_fea_w = train_opt['feature_weight']
self.l_fea_w_decay = train_opt['feature_weight_decay']
self.l_fea_w_decay_start = train_opt['feature_weight_decay_start']
self.l_fea_w_decay_steps = train_opt['feature_weight_decay_steps']
self.l_fea_w_minimum = train_opt['feature_weight_minimum']
if self.l_fea_w_decay_start:
self.l_fea_w_decay_step_size = (self.l_fea_w - self.l_fea_w_minimum) / (self.l_fea_w_decay_steps)
else:
logger.info('Remove feature loss.')
self.cri_fea = None
@ -202,16 +204,17 @@ class SRGANModel(BaseModel):
for p in self.netD.parameters():
p.requires_grad = False
if step > self.D_init_iters:
if step >= self.D_init_iters:
self.optimizer_G.zero_grad()
self.swapout_D(step)
self.swapout_G(step)
# Turning off G-grad is required to enable mega-batching and D_update_ratio to work together for some reason.
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
for p in self.netG.parameters():
p.requires_grad = True
if p.dtype != torch.int64 and p.dtype != torch.bool:
p.requires_grad = True
else:
for p in self.netG.parameters():
p.requires_grad = False
@ -227,35 +230,28 @@ class SRGANModel(BaseModel):
_t = time()
self.fake_GenOut = []
self.fea_GenOut = []
self.fake_H = []
var_ref_skips = []
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
fake_GenOut = self.netG(var_L)
fea_GenOut, fake_GenOut = self.netG(var_L)
if _profile:
print("Gen forward %f" % (time() - _t,))
_t = time()
# Extract the image output. For generators that output skip-through connections, the master output is always
# the first element of the tuple.
if isinstance(fake_GenOut, tuple):
gen_img = fake_GenOut[0]
# The following line detaches all generator outputs that are not None.
self.fake_GenOut.append(tuple([(x.detach() if x is not None else None) for x in list(fake_GenOut)]))
var_ref = (var_ref,) # This is a tuple for legacy reasons.
else:
gen_img = fake_GenOut
self.fake_GenOut.append(fake_GenOut.detach())
self.fake_GenOut.append(fake_GenOut.detach())
self.fea_GenOut.append(fea_GenOut.detach())
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(gen_img, pix)
l_g_pix = self.l_pix_w * self.cri_pix(fea_GenOut, pix)
l_g_pix_log = l_g_pix / self.l_pix_w
l_g_total += l_g_pix
if self.cri_fea: # feature loss
real_fea = self.netF(pix).detach()
fake_fea = self.netF(gen_img)
fake_fea = self.netF(fea_GenOut)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_fea_log = l_g_fea / self.l_fea_w
l_g_total += l_g_fea
@ -266,8 +262,13 @@ class SRGANModel(BaseModel):
# Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
# in the resultant image.
if step % self.l_fea_w_decay_steps == 0:
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
if self.l_fea_w_decay_start and step > self.l_fea_w_decay_start:
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w - self.l_fea_w_decay_step_size * (step - self.l_fea_w_decay_start))
# Note to future self: The BCELoss(0, 1) and BCELoss(0, 0) = .6931
# Effectively this means that the generator has only completely "won" when l_d_real and l_d_fake is
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
# it should target this value.
if self.l_gan_w > 0:
if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
@ -304,7 +305,7 @@ class SRGANModel(BaseModel):
for p in self.netD.parameters():
p.requires_grad = True
noise = torch.randn_like(var_ref[0]) * noise_theta
noise = torch.randn_like(var_ref) * noise_theta
noise.to(self.device)
self.optimizer_D.zero_grad()
real_disc_images = []
@ -312,17 +313,17 @@ class SRGANModel(BaseModel):
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
# Re-compute generator outputs (post-update).
with torch.no_grad():
fake_H = self.netG(var_L)
_, fake_H = self.netG(var_L)
# The following line detaches all generator outputs that are not None.
fake_H = tuple([(x.detach() if x is not None else None) for x in list(fake_H)])
fake_H = fake_H.detach()
if _profile:
print("Gen forward for disc %f" % (time() - _t,))
_t = time()
# Apply noise to the inputs to slow discriminator convergence.
var_ref = (var_ref + noise,)
fake_H = (fake_H[0] + noise,) + fake_H[1:]
var_ref = var_ref + noise
fake_H = fake_H + noise
if self.opt['train']['gan_type'] == 'gan':
# need to forward and backward separately, since batch norm statistics differ
# real
@ -340,10 +341,10 @@ class SRGANModel(BaseModel):
if self.opt['train']['gan_type'] == 'pixgan':
# randomly determine portions of the image to swap to keep the discriminator honest.
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
b, _, w, h = var_ref[0].shape
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device)
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
b, _, w, h = var_ref.shape
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
SWAP_MAX_DIM = w // 4
SWAP_MIN_DIM = 16
assert SWAP_MAX_DIM > 0
@ -360,9 +361,9 @@ class SRGANModel(BaseModel):
swap_w = w - swap_x
if swap_y + swap_h > h:
swap_h = h - swap_y
t = fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
fake_H[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
var_ref[0][:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
@ -422,8 +423,8 @@ class SRGANModel(BaseModel):
_t = time()
# Append var_ref here, so that we can inspect the alterations the disc made if pixgan
var_ref_skips.append(var_ref[0].detach())
self.fake_H.append(fake_H[0].detach())
var_ref_skips.append(var_ref.detach())
self.fake_H.append(fake_H.detach())
self.optimizer_D.step()
@ -436,32 +437,28 @@ class SRGANModel(BaseModel):
sample_save_path = os.path.join(self.opt['path']['models'], "..", "temp")
os.makedirs(os.path.join(sample_save_path, "hr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "lr"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "gen_fea"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "gen"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "disc_fake"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "pix"), exist_ok=True)
os.makedirs(os.path.join(sample_save_path, "disc"), exist_ok=True)
multi_gen = False
if isinstance(self.fake_GenOut[0], tuple):
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
multi_gen = True
os.makedirs(os.path.join(sample_save_path, "ref"), exist_ok=True)
# fed_LQ is not chunked.
for i in range(self.mega_batch_factor):
utils.save_image(self.var_H[i].cpu(), os.path.join(sample_save_path, "hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu(), os.path.join(sample_save_path, "lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[i].cpu(), os.path.join(sample_save_path, "pix", "%05i_%02i.png" % (step, i)))
if multi_gen:
utils.save_image(self.fake_GenOut[i][0].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
else:
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_GenOut[i].cpu(), os.path.join(sample_save_path, "gen", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fea_GenOut[i].cpu(), os.path.join(sample_save_path, "gen_fea", "%05i_%02i.png" % (step, i)))
if self.l_gan_w > 0 and step > self.G_warmup and self.opt['train']['gan_type'] == 'pixgan':
utils.save_image(var_ref_skips[i].cpu(), os.path.join(sample_save_path, "ref", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[i], os.path.join(sample_save_path, "disc_fake", "fake%05i_%02i.png" % (step, i)))
utils.save_image(F.interpolate(fake_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "fake%05i_%02i.png" % (step, i)))
utils.save_image(F.interpolate(real_disc_images[i], scale_factor=4), os.path.join(sample_save_path, "disc", "real%05i_%02i.png" % (step, i)))
# Log metrics
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if step % self.D_update_ratio == 0 and step >= self.D_init_iters:
if self.cri_pix:
self.add_log_entry('l_g_pix', l_g_pix_log.item())
if self.cri_fea:

View File

@ -196,7 +196,8 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module):
if self.upsample_factor > 2:
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = self.upconv2(x)
return self.final_conv(self.hr_conv(x)),
x = self.final_conv(self.hr_conv(x))
return x, x
def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches]
@ -318,4 +319,106 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module):
for i in range(len(means)):
val["switch_%i_specificity" % (i,)] = means[i]
val["switch_%i_histogram" % (i,)] = hists[i]
return val
return val
class DualOutputSRG(nn.Module):
    """Switched residual generator that emits two images per forward pass.

    The input is passed through a chain of ``ConfigurableSwitchComputer``
    blocks. The activations coming out of the second-to-last switch are
    rendered by a dedicated "fea" head, and the activations coming out of the
    last switch are rendered by the main head. ``forward`` returns both images
    as ``(fea_image, final_image)``.

    Args:
        switch_depth: Number of switches in the chain. Must be at least 2,
            because the "fea" head taps the second-to-last switch.
        switch_filters, switch_reductions, switch_processing_layers:
            Forwarded unchanged to ``ConvBasisMultiplexer``.
        trans_counts: Number of transforms per switch; also forwarded to the
            multiplexer and to ``save_attention_to_image``.
        trans_kernel_sizes, trans_layers: Forwarded to ``MultiConvBlock`` as
            ``kernel_size`` and ``depth``.
        transformation_filters: Channel width used by every convolution in
            the trunk and in both output heads.
        initial_temp: Initial switch temperature; also the starting value of
            the linear schedule in ``update_for_step``.
        final_temperature_step: Step at which the linear schedule bottoms out.
        heightened_temp_min, heightened_final_step: Parameters of the second
            ("heightened") phase of the temperature schedule.
        upsample_factor: Must be 2 or 4 (enforced by an assert, so the
            default of 1 is rejected unless asserts are disabled).
        add_scalable_noise_to_transforms: Forwarded to each switch.
    """

    def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes,
                 trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1,
                 heightened_final_step=50000, upsample_factor=1,
                 add_scalable_noise_to_transforms=False):
        super().__init__()
        if switch_depth < 2:
            # forward() taps switch index (depth - 2); with fewer than two
            # switches the feature output would never be produced.
            raise ValueError("DualOutputSRG requires switch_depth >= 2, got %r" % (switch_depth,))

        # NOTE: modules are created in the same order as before so that
        # parameter initialisation and state_dict ordering are unchanged.
        self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True)

        # Head that renders the intermediate ("fea") image.
        self.fea_upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
        self.fea_upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
        self.fea_hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
        self.fea_final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)

        # Head that renders the final image.
        self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
        self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
        self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True)
        self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True)

        switches = []
        for _ in range(switch_depth):
            multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts)
            pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1)
            transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1)
            switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
                                                       pre_transform_block=pretransform_fn, transform_block=transform_fn,
                                                       transform_count=trans_counts, init_temp=initial_temp,
                                                       add_scalable_noise_to_transforms=add_scalable_noise_to_transforms))

        self.switches = nn.ModuleList(switches)
        self.transformation_counts = trans_counts
        self.init_temperature = initial_temp
        self.final_temperature_step = final_temperature_step
        self.heightened_temp_min = heightened_temp_min
        self.heightened_final_step = heightened_final_step
        # Filled in by forward(); None until the first forward pass.
        self.attentions = None
        self.upsample_factor = upsample_factor
        assert self.upsample_factor == 2 or self.upsample_factor == 4

    def _upsample_to_image(self, x, upconv1, upconv2, hr_conv, final_conv):
        """Upsample trunk features by ``upsample_factor`` and project to 3 channels."""
        x = upconv1(F.interpolate(x, scale_factor=2, mode="nearest"))
        if self.upsample_factor > 2:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = upconv2(x)
        return final_conv(hr_conv(x))

    def forward(self, x):
        """Run the switch chain and return ``(fea_image, final_image)``.

        As a side effect, ``self.attentions`` is replaced with the list of
        attention maps returned by each switch during this pass.
        """
        x = self.initial_conv(x)

        self.attentions = []
        fea = None
        fea_tap = len(self.switches) - 2
        for i, sw in enumerate(self.switches):
            x, att = sw.forward(x, True)
            self.attentions.append(att)
            if i == fea_tap:
                # Render the intermediate image with the dedicated fea_* head.
                # (Previously this went through self.hr_conv, which left
                # fea_hr_conv unused and tied the two heads together.)
                fea = self._upsample_to_image(x, self.fea_upconv1, self.fea_upconv2,
                                              self.fea_hr_conv, self.fea_final_conv)

        final = self._upsample_to_image(x, self.upconv1, self.upconv2,
                                        self.hr_conv, self.final_conv)
        return fea, final

    def set_temperature(self, temp):
        """Set the same temperature on every switch."""
        for sw in self.switches:
            sw.set_temperature(temp)

    def update_for_step(self, step, experiments_path='.'):
        """Advance the temperature schedule and periodically dump attention maps.

        Does nothing until ``forward`` has populated ``self.attentions``.
        Every 50th step the attention of each switch is written via
        ``save_attention_to_image`` under ``experiments_path``.
        """
        if not self.attentions:
            return

        # Phase 1: linear decay from init_temperature, floored at 1.
        temp = max(1, int(self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step))
        if temp == 1 and self.heightened_final_step and self.heightened_final_step != 1:
            # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above.
            # without this, the attention specificity "spikes" incredibly fast in the last few iterations.
            h_steps_total = self.heightened_final_step - self.final_temperature_step
            # Only enter the heightened phase when it has a positive length.
            # With heightened_final_step == final_temperature_step (the
            # constructor defaults) the division below would be by zero.
            if h_steps_total > 0:
                h_steps_current = max(min(step - self.final_temperature_step, h_steps_total), 1)
                # The "gap" will represent the steps that need to be traveled as a linear function.
                h_gap = 1 / self.heightened_temp_min
                temp = h_gap * h_steps_current / h_steps_total
                # Invert temperature to represent reality on this side of the curve
                temp = 1 / temp
        self.set_temperature(temp)

        if step % 50 == 0:
            for i in range(len(self.switches)):
                save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,))

    def get_debug_values(self, step):
        """Return a dict with the current temperature and per-switch attention stats.

        Requires that ``forward`` has been called so ``self.attentions`` is set.
        """
        temp = self.switches[0].switch.temperature
        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
        val = {"switch_temperature": temp}
        for i, (mean, hist) in enumerate(mean_hists):
            val["switch_%i_specificity" % (i,)] = mean
            val["switch_%i_histogram" % (i,)] = hist.clone().detach().cpu().flatten()
        return val

    def load_state_dict(self, state_dict, strict=True):
        """Load weights, back-filling attention-norm buffers missing from old checkpoints.

        If the checkpoint has no ``accumulator_index`` entry for switch 0, the
        ``accumulator``, ``accumulator_index`` and ``accumulator_filled``
        entries of every switch are taken from this module's current state.
        Note that ``state_dict`` is modified in place in that case.

        Returns whatever ``nn.Module.load_state_dict`` returns.
        """
        # Support backwards compatibility where accumulator_index and accumulator_filled are not in this state_dict
        if 'switches.0.switch.attention_norm.accumulator_index' not in state_dict:
            t_state = self.state_dict()
            # Cover however many switches this instance has, not a fixed 4.
            for i in range(len(self.switches)):
                for suffix in ('accumulator', 'accumulator_index', 'accumulator_filled'):
                    key = 'switches.%i.switch.attention_norm.%s' % (i, suffix)
                    state_dict[key] = t_state[key]
        return super().load_state_dict(state_dict, strict)

View File

@ -51,7 +51,6 @@ class Discriminator_VGG_128(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
x = x[0]
fea = self.lrelu(self.conv0_0(x))
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
@ -127,7 +126,6 @@ class Discriminator_VGG_PixLoss(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x, flatten=True):
x = x[0]
fea0 = self.lrelu(self.conv0_0(x))
fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0)))
@ -205,7 +203,6 @@ class Discriminator_UNet(nn.Module):
self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
def forward(self, x, flatten=True):
x = x[0]
fea0 = self.conv0_0(x)
fea0 = self.conv0_1(fea0)

View File

@ -78,6 +78,15 @@ def define_G(opt, net_key='network_G'):
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == "DualOutputSRG":
netG = SwitchedGen_arch.DualOutputSRG(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
switch_reductions=opt_net['switch_reductions'],
switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
transformation_filters=opt_net['transformation_filters'],
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
# image corruption
elif which_model == 'HighToLowResNet':

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_dual_srg.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)

View File

@ -93,7 +93,7 @@ if __name__ == "__main__":
torch.randn(1, 3, 64, 64),
device='cuda')
'''
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2,
test_stability(functools.partial(srg.DualOutputSRG,
switch_depth=4,
switch_filters=64,
switch_reductions=4,