Add switched discriminator
The rationale is that a single discriminator may be incapable of providing a truly targeted loss for every image region, because it has to remain very generic (essentially the same argument that motivated the switched generator). So this commit adds switches to the discriminator to see how well that works.
This commit is contained in:
parent
8a0a1569f3
commit
dbf6147504
|
@ -116,9 +116,16 @@ class SRGANModel(BaseModel):
|
|||
weight_decay=wd_G,
|
||||
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
optim_params = []
|
||||
for k, v in self.netD.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
optim_params.append(v)
|
||||
else:
|
||||
if self.rank <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
# D
|
||||
wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
|
||||
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
|
||||
self.optimizer_D = torch.optim.Adam(optim_params, lr=train_opt['lr_D'],
|
||||
weight_decay=wd_D,
|
||||
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
|
||||
self.optimizers.append(self.optimizer_D)
|
||||
|
@ -219,6 +226,8 @@ class SRGANModel(BaseModel):
|
|||
# Some generators have variants depending on the current step.
|
||||
if hasattr(self.netG.module, "update_for_step"):
|
||||
self.netG.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||
if hasattr(self.netD.module, "update_for_step"):
|
||||
self.netD.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
|
||||
|
||||
# G
|
||||
for p in self.netD.parameters():
|
||||
|
@ -323,7 +332,8 @@ class SRGANModel(BaseModel):
|
|||
# D
|
||||
if self.l_gan_w > 0 and step > self.G_warmup:
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = True
|
||||
if p.dtype != torch.int64 and p.dtype != torch.bool:
|
||||
p.requires_grad = True
|
||||
|
||||
noise = torch.randn_like(var_ref) * noise_theta
|
||||
noise.to(self.device)
|
||||
|
@ -610,6 +620,8 @@ class SRGANModel(BaseModel):
|
|||
# Some generators can do their own metric logging.
|
||||
if hasattr(self.netG.module, "get_debug_values"):
|
||||
return_log.update(self.netG.module.get_debug_values(step))
|
||||
if hasattr(self.netD.module, "get_debug_values"):
|
||||
return_log.update(self.netD.module.get_debug_values(step))
|
||||
|
||||
return return_log
|
||||
|
||||
|
|
|
@ -238,6 +238,155 @@ class Discriminator_UNet(nn.Module):
|
|||
return 3, 4
|
||||
|
||||
|
||||
import functools
|
||||
from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConfigurableSwitchComputer, BareConvSwitch
|
||||
from switched_conv_util import save_attention_to_image
|
||||
from switched_conv import compute_attention_specificity, AttentionNorm
|
||||
|
||||
|
||||
class ExpandAndCollapse(nn.Module):
    """Run an ExpansionBlock and then collapse its output to `num_channels` channels.

    Within this file it is used as the multiplexer network of a switched
    upsampler (see `create_switched_upsampler`).

    Args:
        nf: first positional argument forwarded to `ExpansionBlock`.
        nf_out: second positional argument forwarded to `ExpansionBlock`, and
            the input channel count of the collapsing convolution.
        num_channels: output channel count of the collapsing convolution.
    """

    def __init__(self, nf, nf_out, num_channels):
        super().__init__()
        self.expand = ExpansionBlock(nf, nf_out, block=ConvGnLelu)
        # The collapsing conv is created with norm, bias and activation all
        # disabled.
        self.collapse = ConvGnLelu(
            nf_out,
            num_channels,
            norm=False,
            bias=False,
            activation=False,
        )

    def forward(self, x, passthrough):
        """Expand `x` together with `passthrough`, then collapse the result."""
        expanded = self.expand(x, passthrough)
        collapsed = self.collapse(expanded)
        return collapsed
|
||||
|
||||
|
||||
# Differs from ConfigurableSwitchComputer in that the connections are not residual and the multiplexer is fed directly in.
|
||||
class ConfigurableLinearSwitchComputer(nn.Module):
    """Switch computer whose output is the switched transforms only (no residual add).

    Every transform receives the (optionally noised, pre-transformed) input
    together with `passthrough`; the multiplexer receives the *original* input
    together with `passthrough`, and its output drives the switch.

    Args:
        out_filters: channel count used for both sides of `post_switch_conv`.
        multiplexer_net: module called as `multiplexer_net(x, passthrough)`.
        pre_transform_block: module applied to `x` before the transforms.
        transform_block: zero-argument factory; called `transform_count` times
            to build the transforms.
        transform_count: number of transforms to instantiate.
        attention_norm: when truthy, the switch is given an `AttentionNorm`
            sized from `transform_count`; otherwise it is given `None`.
        init_temp: initial temperature handed to `BareConvSwitch`.
        add_scalable_noise_to_transforms: when truthy, Gaussian noise scaled
            by the learned `noise_scale` is added to `x` before pre-transform.
    """

    def __init__(self, out_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm,
                 init_temp=20, add_scalable_noise_to_transforms=False):
        super(ConfigurableLinearSwitchComputer, self).__init__()

        self.multiplexer = multiplexer_net
        self.pre_transform = pre_transform_block
        self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)])
        self.add_noise = add_scalable_noise_to_transforms
        self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3)))

        # And the switch itself, including learned scalars
        self.switch = BareConvSwitch(
            initial_temperature=init_temp,
            attention_norm=AttentionNorm(transform_count, accumulator_size=16 * transform_count) if attention_norm else None)
        self.post_switch_conv = ConvBnLelu(out_filters, out_filters, norm=False, bias=True)
        # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not)
        # depending on its needs.
        # NOTE(review): psc_scale is registered here but forward() never reads
        # it. It is kept so existing checkpoints keep loading; confirm whether
        # it was meant to scale the post_switch_conv output.
        self.psc_scale = nn.Parameter(torch.full((1,), float(.1)))

    def forward(self, x, passthrough, output_attention_weights=False, extra_arg=None):
        """Compute the switched output for `x`.

        Args:
            x: input fed to the pre-transform/transforms and to the multiplexer.
            passthrough: second input handed to every transform and to the
                multiplexer.
            output_attention_weights: when truthy, also return the attention
                produced by the switch.
            extra_arg: accepted but not used by this implementation.

        Returns:
            `outputs`, or `(outputs, attention)` when
            `output_attention_weights` is truthy.
        """
        # The multiplexer always sees the input as it was before any noise or
        # pre-transform is applied.
        multiplexer_input = x
        if self.add_noise:
            x = x + torch.randn_like(x) * self.noise_scale

        x = self.pre_transform(x)
        # Call the modules themselves rather than `.forward` so that registered
        # forward/pre-forward hooks are honoured.
        xformed = [transform(x, passthrough) for transform in self.transforms]
        m = self.multiplexer(multiplexer_input, passthrough)

        outputs, attention = self.switch(xformed, m, True)
        outputs = self.post_switch_conv(outputs)
        if output_attention_weights:
            return outputs, attention
        return outputs

    def set_temperature(self, temp):
        """Forward `temp` to the underlying switch's attention temperature."""
        self.switch.set_attention_temperature(temp)
|
||||
|
||||
|
||||
def create_switched_upsampler(nf, nf_out, num_channels, initial_temp=10):
    """Assemble a ConfigurableLinearSwitchComputer built from ExpansionBlocks.

    Args:
        nf: first size argument for the multiplexer, the pre-transform conv
            and every transform ExpansionBlock.
        nf_out: second size argument for the multiplexer and the transforms;
            also passed as the switch computer's `out_filters`.
        num_channels: number of transforms in the switch, and the channel
            count the multiplexer collapses to.
        initial_temp: initial temperature of the switch.

    Returns:
        The configured ConfigurableLinearSwitchComputer, with attention norm
        enabled and scalable noise disabled.
    """
    # Sub-modules are built in the same order as before so that module
    # construction order is unchanged.
    multiplexer = ExpandAndCollapse(nf, nf_out, num_channels)
    pre_transform = ConvGnLelu(nf, nf, norm=True, bias=False)
    make_transform = functools.partial(ExpansionBlock, nf, nf_out, block=ConvGnLelu)

    switch_options = dict(
        pre_transform_block=pre_transform,
        transform_block=make_transform,
        attention_norm=True,
        transform_count=num_channels,
        init_temp=initial_temp,
        add_scalable_noise_to_transforms=False,
    )
    return ConfigurableLinearSwitchComputer(nf_out, multiplexer, **switch_options)
|
||||
|
||||
|
||||
class Discriminator_switched(nn.Module):
    """Discriminator with two switched upsampling stages.

    The input goes through five conv pairs (the second conv of each pair has
    stride 2), is expanded once with a plain ExpansionBlock and then twice
    with switched upsamplers that take earlier feature maps as passthrough.
    The result is collapsed to a single channel and returned with shape
    (-1, 1).

    Args:
        in_nc: number of input channels.
        nf: base filter count; later stages use multiples of it.
        initial_temp: starting temperature of the switches and of the
            temperature schedule in `update_for_step`.
        final_temperature_step: step at which the schedule reaches 1.
    """

    def __init__(self, in_nc, nf, initial_temp=10, final_temperature_step=50000):
        super(Discriminator_switched, self).__init__()
        # [64, 128, 128]
        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
        # [64, 64, 64]
        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
        # [128, 32, 32]
        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
        # [256, 16, 16]
        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
        # [512, 8, 8]
        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)

        self.exp1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
        # Forward the configured initial temperature so the switches start at
        # the same value the schedule in update_for_step() starts from
        # (previously they always used create_switched_upsampler's default).
        self.upsw2 = create_switched_upsampler(nf * 8, nf * 4, 8, initial_temp=initial_temp)
        self.upsw3 = create_switched_upsampler(nf * 4, nf * 2, 8, initial_temp=initial_temp)
        # Plain list used for iteration only; both modules are already
        # registered through the attributes above.
        self.switches = [self.upsw2, self.upsw3]
        self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)

        self.init_temperature = initial_temp
        self.final_temperature_step = final_temperature_step
        # Attention maps of the most recent forward pass; None until then.
        self.attentions = None

    def forward(self, x, flatten=True):
        """Score `x`.

        Args:
            x: input batch.
            flatten: accepted but not used; the output is always reshaped
                with `view(-1, 1)`.

        Returns:
            The collapsed single-channel output reshaped to (-1, 1).
        """
        fea0 = self.conv0_0(x)
        fea0 = self.conv0_1(fea0)

        fea1 = self.conv1_0(fea0)
        fea1 = self.conv1_1(fea1)

        fea2 = self.conv2_0(fea1)
        fea2 = self.conv2_1(fea2)

        fea3 = self.conv3_0(fea2)
        fea3 = self.conv3_1(fea3)

        fea4 = self.conv4_0(fea3)
        fea4 = self.conv4_1(fea4)

        u1 = self.exp1(fea4, fea3)
        u2, a1 = self.upsw2(u1, fea2, output_attention_weights=True)
        u3, a2 = self.upsw3(u2, fea1, output_attention_weights=True)
        # Keep the attention maps for update_for_step() / get_debug_values().
        self.attentions = [a1, a2]
        loss3 = self.collapse3(self.proc3(u3))
        return loss3.view(-1, 1)

    def pixgan_parameters(self):
        # NOTE(review): meaning of the two values is defined by the caller,
        # which is not visible here.
        return 1, 4

    def set_temperature(self, temp):
        """Set the attention temperature of every switch to `temp`."""
        for sw in self.switches:
            sw.set_temperature(temp)

    def update_for_step(self, step, experiments_path='.'):
        """Anneal the switch temperature and periodically dump attention images.

        Does nothing until a forward pass has populated `self.attentions`.
        The temperature decreases linearly with `step` from
        `init_temperature` and is clamped to [1, init_temperature]. Every
        50th step the stored attention maps are saved via
        `save_attention_to_image`.
        """
        if not self.attentions:
            return

        # The schedule does not depend on the switch, so compute it once.
        temp_loss_per_step = (self.init_temperature - 1) / self.final_temperature_step
        temperature = min(self.init_temperature,
                          max(self.init_temperature - temp_loss_per_step * step, 1))
        for sw in self.switches:
            sw.set_temperature(temperature)

        if step % 50 == 0:
            for i, attention in enumerate(self.attentions):
                save_attention_to_image(experiments_path, attention, 8, step, "disc_a%i" % (i + 1,), l_mult=10)

    def get_debug_values(self, step):
        """Return a dict of switch diagnostics for logging.

        Always contains "disc_switch_temperature". Once a forward pass has
        stored attention maps, it also contains a specificity value and a
        histogram per switch.
        """
        val = {"disc_switch_temperature": self.switches[0].switch.temperature}
        # No forward pass has happened yet (e.g. logging before the
        # discriminator is first run), so there is nothing to analyse.
        if not self.attentions:
            return val

        for i, att in enumerate(self.attentions):
            mean, hist = compute_attention_specificity(att, 2)[:2]
            val["disc_switch_%i_specificity" % (i,)] = mean
            val["disc_switch_%i_histogram" % (i,)] = hist.clone().detach().cpu().flatten()
        return val
|
||||
|
||||
|
||||
class Discriminator_UNet_FeaOut(nn.Module):
|
||||
def __init__(self, in_nc, nf, feature_mode=False):
|
||||
super(Discriminator_UNet_FeaOut, self).__init__()
|
||||
|
|
|
@ -124,6 +124,9 @@ def define_D(opt):
|
|||
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
|
||||
elif which_model == "discriminator_unet_fea":
|
||||
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
|
||||
elif which_model == "discriminator_switched":
|
||||
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
|
||||
final_temperature_step=opt_net['final_temperature_step'])
|
||||
else:
|
||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||
return netD
|
||||
|
|
|
@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_progressive_srg2.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_switched_disc.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
@ -161,7 +161,7 @@ def main():
|
|||
current_step = resume_state['iter']
|
||||
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||
else:
|
||||
current_step = 0
|
||||
current_step = -1
|
||||
start_epoch = 0
|
||||
|
||||
#### training
|
||||
|
|
|
@ -42,8 +42,6 @@ def copy_state_dict(dict_from, dict_to):
|
|||
|
||||
if __name__ == "__main__":
|
||||
os.chdir("..")
|
||||
torch.backends.cudnn.benchmark = True
|
||||
want_just_images = True
|
||||
model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
|
||||
model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user