Add switched discriminator

The reasoning is that the discriminator may be incapable of providing a truly
targeted loss for every image region, since a single network has to stay too
generic (basically the same argument that motivated the switched generator).
So add some switches in and see how it works!
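
For orientation, here is a minimal, hypothetical sketch of the switching idea in plain PyTorch (it is not the repo's BareConvSwitch/ConfigurableSwitchComputer, just an illustration): a multiplexer predicts per-pixel attention over several transform branches, and the output is the attention-weighted mix, so different image regions can be served by different specialized filters.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySwitch(nn.Module):
    def __init__(self, channels, num_branches, temperature=10.0):
        super().__init__()
        # Specialized transforms the switch can choose between at each pixel.
        self.branches = nn.ModuleList(
            [nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_branches)])
        # Multiplexer: one selection logit per branch at every spatial location.
        self.multiplexer = nn.Conv2d(channels, num_branches, 1)
        self.temperature = temperature

    def forward(self, x):
        stacked = torch.stack([b(x) for b in self.branches], dim=1)      # (B, K, C, H, W)
        attn = F.softmax(self.multiplexer(x) / self.temperature, dim=1)  # (B, K, H, W)
        return (stacked * attn.unsqueeze(2)).sum(dim=1)                  # per-pixel weighted mix

y = ToySwitch(channels=8, num_branches=4)(torch.randn(1, 8, 16, 16))     # -> (1, 8, 16, 16)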
James Betker 2020-07-22 20:52:59 -06:00
parent 8a0a1569f3
commit dbf6147504
5 changed files with 168 additions and 6 deletions

View File

@@ -116,9 +116,16 @@ class SRGANModel(BaseModel):
                                                 weight_decay=wd_G,
                                                 betas=(train_opt['beta1_G'], train_opt['beta2_G']))
             self.optimizers.append(self.optimizer_G)
+            optim_params = []
+            for k, v in self.netD.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    if self.rank <= 0:
+                        logger.warning('Params [{:s}] will not optimize.'.format(k))
             # D
             wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
-            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
+            self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
                                                 weight_decay=wd_D,
                                                 betas=(train_opt['beta1_D'], train_opt['beta2_D']))
             self.optimizers.append(self.optimizer_D)
@@ -219,6 +226,8 @@ class SRGANModel(BaseModel):
         # Some generators have variants depending on the current step.
         if hasattr(self.netG.module, "update_for_step"):
             self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
+        if hasattr(self.netD.module, "update_for_step"):
+            self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))

         # G
         for p in self.netD.parameters():
@@ -323,7 +332,8 @@ class SRGANModel(BaseModel):
         # D
         if self.l_gan_w > 0 and step > self.G_warmup:
             for p in self.netD.parameters():
-                p.requires_grad = True
+                if p.dtype != torch.int64 and p.dtype != torch.bool:
+                    p.requires_grad = True

             noise = torch.randn_like(var_ref) * noise_theta
             noise.to(self.device)
@@ -610,6 +620,8 @@ class SRGANModel(BaseModel):
         # Some generators can do their own metric logging.
         if hasattr(self.netG.module, "get_debug_values"):
             return_log.update(self.netG.module.get_debug_values(step))
+        if hasattr(self.netD.module, "get_debug_values"):
+            return_log.update(self.netD.module.get_debug_values(step))

         return return_log
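
Both of the hooks used above are optional and discovered with hasattr() on netD.module, so any discriminator can opt in by exposing the same two methods. A hypothetical minimal example (method names taken from the calls above; the bodies are illustrative only):

import torch.nn as nn

class MyDiscriminator(nn.Module):
    # Optional hook: called every step, e.g. to anneal a temperature or dump debug images.
    def update_for_step(self, step, experiments_path='.'):
        pass

    # Optional hook: the returned dict is merged into the training log for this step.
    def get_debug_values(self, step):
        return {'disc_custom_metric': 0.0}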

View File

@@ -238,6 +238,155 @@ class Discriminator_UNet(nn.Module):
         return 3, 4
+
+
+import functools
+from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConfigurableSwitchComputer, BareConvSwitch
+from switched_conv_util import save_attention_to_image
+from switched_conv import compute_attention_specificity, AttentionNorm
+
+
+class ExpandAndCollapse(nn.Module):
+    def __init__(self, nf, nf_out, num_channels):
+        super(ExpandAndCollapse, self).__init__()
+        self.expand = ExpansionBlock(nf, nf_out, block=ConvGnLelu)
+        self.collapse = ConvGnLelu(nf_out, num_channels, norm=False, bias=False, activation=False)
+
+    def forward(self, x, passthrough):
+        x = self.expand(x, passthrough)
+        return self.collapse(x)
+
+
+# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
+class ConfigurableLinearSwitchComputer(nn.Module):
+    def __init__(self, out_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
+                 init_temp=20, add_scalable_noise_to_transforms=False):
+        super(ConfigurableLinearSwitchComputer, self).__init__()
+
+        self.multiplexer = multiplexer_net
+        self.pre_transform = pre_transform_block
+        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
+        self.add_noise = add_scalable_noise_to_transforms
+        self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))
+
+        # And the switch itself, including learned scalars
+        self.switch = BareConvSwitch(initial_temperature=init_temp,
+                                     attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
+        self.post_switch_conv = ConvBnLelu(out_filters, out_filters, norm=False, bias=True)
+        # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
+        # depending on its needs.
+        self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))
+
+    def forward(self, x, passthrough, output_attention_weights=False, extra_arg=None):
+        identity = x
+        if self.add_noise:
+            rand_feature = torch.randn_like(x) * self.noise_scale
+            x = x + rand_feature
+
+        x = self.pre_transform(x)
+        xformed = [t.forward(x, passthrough) for t in self.transforms]
+        m = self.multiplexer(identity, passthrough)
+
+        outputs, attention = self.switch(xformed, m, True)
+        outputs = self.post_switch_conv(outputs)
+        if output_attention_weights:
+            return outputs, attention
+        else:
+            return outputs
+
+    def set_temperature(self, temp):
+        self.switch.set_attention_temperature(temp)
+
+
+def create_switched_upsampler(nf, nf_out, num_channels, initial_temp=10):
+    multiplx = ExpandAndCollapse(nf, nf_out, num_channels)
+    pretransform = ConvGnLelu(nf, nf, norm=True, bias=False)
+    transform_fn = functools.partial(ExpansionBlock, nf, nf_out, block=ConvGnLelu)
+    return ConfigurableLinearSwitchComputer(nf_out, multiplx,
+                                            pre_transform_block=pretransform, transform_block=transform_fn,
+                                            attention_norm=True,
+                                            transform_count=num_channels, init_temp=initial_temp,
+                                            add_scalable_noise_to_transforms=False)
+
+
+class Discriminator_switched(nn.Module):
+    def __init__(self, in_nc, nf, initial_temp=10, final_temperature_step=50000):
+        super(Discriminator_switched, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
+        self.upsw2 = create_switched_upsampler(nf * 8, nf * 4, 8)
+        self.upsw3 = create_switched_upsampler(nf * 4, nf * 2, 8)
+        self.switches = [self.upsw2, self.upsw3]
+        self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+
+        self.init_temperature = initial_temp
+        self.final_temperature_step = final_temperature_step
+        self.attentions = None
+
+    def forward(self, x, flatten=True):
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        fea4 = self.conv4_0(fea3)
+        fea4 = self.conv4_1(fea4)
+
+        u1 = self.exp1(fea4, fea3)
+        u2, a1 = self.upsw2(u1, fea2, output_attention_weights=True)
+        u3, a2 = self.upsw3(u2, fea1, output_attention_weights=True)
+        self.attentions = [a1, a2]
+        loss3 = self.collapse3(self.proc3(u3))
+        return loss3.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 1, 4
+
+    def set_temperature(self, temp):
+        [sw.set_temperature(temp) for sw in self.switches]
+
+    def update_for_step(self, step, experiments_path='.'):
+        if self.attentions:
+            for i, sw in enumerate(self.switches):
+                temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
+                sw.set_temperature(min(self.init_temperature,
+                                       max(self.init_temperature - temp_loss_per_step * step, 1)))
+            if step % 50 == 0:
+                [save_attention_to_image(experiments_path, self.attentions[i], 8, step, "disc_a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))]
+
+    def get_debug_values(self, step):
+        temp = self.switches[0].switch.temperature
+        mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions]
+        means = [i[0] for i in mean_hists]
+        hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists]
+        val = {"disc_switch_temperature": temp}
+        for i in range(len(means)):
+            val["disc_switch_%i_specificity" % (i,)] = means[i]
+            val["disc_switch_%i_histogram" % (i,)] = hists[i]
+        return val
+
+
 class Discriminator_UNet_FeaOut(nn.Module):
     def __init__(self, in_nc, nf, feature_mode=False):
         super(Discriminator_UNet_FeaOut, self).__init__()

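A note on Discriminator_switched.update_for_step above: the switch temperature is annealed linearly from initial_temp down to 1 over final_temperature_step steps. A quick sanity check of that schedule with the default values (plain Python, not code from the commit):

init_temp, final_step = 10, 50000
temp_loss_per_step = (init_temp - 1) / final_step                              # 0.00018 lost per step
temperature_at = lambda step: min(init_temp, max(init_temp - temp_loss_per_step * step, 1))
print(temperature_at(0), temperature_at(25000), temperature_at(50000))         # 10 5.5 1.0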
View File

@@ -124,6 +124,9 @@ def define_D(opt):
         netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     elif which_model == "discriminator_unet_fea":
         netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
+    elif which_model == "discriminator_switched":
+        netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
+                                                 final_temperature_step=opt_net['final_temperature_step'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
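
To exercise the new branch end to end, the option file also needs the extra initial_temp and final_temperature_step keys read above. A quick standalone smoke test of the architecture itself (a sketch; the import path and parameter values are assumptions, not part of the commit):

import torch
import models.archs.SRGAN_arch as SRGAN_arch

netD = SRGAN_arch.Discriminator_switched(in_nc=3, nf=64, initial_temp=10, final_temperature_step=50000)
scores = netD(torch.randn(2, 3, 128, 128))   # per-region predictions flattened to (N, 1)
netD.update_for_step(step=101)               # anneals the switch temperature (101 avoids the image dump)
print(scores.shape, netD.get_debug_values(101)["disc_switch_temperature"])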

View File

@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -161,7 +161,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = 0
+        current_step = -1
         start_epoch = 0

     #### training

View File

@@ -42,8 +42,6 @@ def copy_state_dict(dict_from, dict_to):
 if __name__ == "__main__":
     os.chdir("..")
-    torch.backends.cudnn.benchmark = True
-    want_just_images = True
     model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
     model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")