diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index cc189a11..649087c2 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -178,10 +178,10 @@ class ExtensibleTrainer(BaseModel): state = self.dstate for step_num, s in enumerate(self.steps): # Skip steps if mod_step doesn't line up. - if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0: + if 'mod_step' in s.step_opt.keys() and step % s.step_opt['mod_step'] != 0: continue # Steps can opt out of early (or late) training, make sure that happens here. - if 'after' in s.opt.keys() and step < s.opt['after'] or 'before' in s.opt.keys() and step > s.opt['before']: + if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']: continue # Only set requires_grad=True for the network being trained. diff --git a/codes/models/archs/feature_arch.py b/codes/models/archs/feature_arch.py index 6d182231..bb2b4371 100644 --- a/codes/models/archs/feature_arch.py +++ b/codes/models/archs/feature_arch.py @@ -6,7 +6,7 @@ import torch.nn.functional as F # Utilizes pretrained torchvision modules for feature extraction class VGGFeatureExtractor(nn.Module): - def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, + def __init__(self, feature_layers=[34], use_bn=False, use_input_norm=True, device=torch.device('cpu')): super(VGGFeatureExtractor, self).__init__() self.use_input_norm = use_input_norm @@ -21,7 +21,8 @@ class VGGFeatureExtractor(nn.Module): # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] self.register_buffer('mean', mean) self.register_buffer('std', std) - self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) + self.feature_layers = feature_layers + self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)]) # No need to BP to variable for k, v in self.features.named_parameters(): v.requires_grad = False diff --git a/codes/models/networks.py b/codes/models/networks.py index e74cd96f..8f56ddd9 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -159,18 +159,19 @@ def define_fixed_D(opt): # Define network used for perceptual loss -def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None): +def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None): if which_model == 'vgg': # PyTorch pretrained VGG19-54, before ReLU. - if use_bn: - feature_layer = 49 - else: - feature_layer = 34 + if feature_layers is None: + if use_bn: + feature_layers = [49] + else: + feature_layers = [34] if for_training: - netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, + netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn, use_input_norm=True) else: - netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, + netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn, use_input_norm=True) elif which_model == 'wide_resnet': netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True) diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 259f91ee..1953c35c 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -113,20 +113,35 @@ class GeneratorGanLoss(ConfigurableLoss): super(GeneratorGanLoss, self).__init__(opt, env) self.opt = opt self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) + self.noise = None if 'noise' not in opt.keys() else opt['noise'] + self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True def forward(self, _, state): netD = self.env['discriminators'][self.opt['discriminator']] + real = extract_params_from_state(self.opt['real'], state) fake = extract_params_from_state(self.opt['fake'], state) + if self.noise: + nreal = [] + nfake = [] + for i, t in enumerate(real): + if isinstance(t, torch.Tensor): + nreal.append(t + torch.randn_like(t) * self.noise) + nfake.append(fake[i] + torch.randn_like(t) * self.noise) + else: + nreal.append(t) + nfake.append(fake[i]) + real = nreal + fake = nfake if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']: pred_g_fake = netD(*fake) return self.criterion(pred_g_fake, True) elif self.opt['gan_type'] == 'ragan': - real = extract_params_from_state(self.opt['real'], state) - real = [r.detach() for r in real] - pred_d_real = netD(*real).detach() + pred_d_real = netD(*real) + if self.detach_real: + pred_d_real = pred_d_real.detach() pred_g_fake = netD(*fake) - return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + - self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 + return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) + + self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2 else: raise NotImplementedError @@ -142,6 +157,7 @@ class DiscriminatorGanLoss(ConfigurableLoss): def forward(self, net, state): self.metrics = [] real = extract_params_from_state(self.opt['real'], state) + real = [r.detach() for r in real] fake = extract_params_from_state(self.opt['fake'], state) fake = [f.detach() for f in fake] if self.noise: @@ -159,17 +175,18 @@ class DiscriminatorGanLoss(ConfigurableLoss): d_real = net(*real) d_fake = net(*fake) - self.metrics.append(("d_fake", torch.mean(d_fake))) - self.metrics.append(("d_real", torch.mean(d_real))) - if self.opt['gan_type'] in ['gan', 'pixgan']: + self.metrics.append(("d_fake", torch.mean(d_fake))) + self.metrics.append(("d_real", torch.mean(d_real))) l_real = self.criterion(d_real, True) l_fake = self.criterion(d_fake, False) l_total = l_real + l_fake return l_total elif self.opt['gan_type'] == 'ragan': + d_fake_diff = d_fake - torch.mean(d_real) + self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) return (self.criterion(d_real - torch.mean(d_fake), True) + - self.criterion(d_fake - torch.mean(d_real), False)) + self.criterion(d_fake_diff, False)) else: raise NotImplementedError