diff --git a/codes/models/discriminator_vgg_arch.py b/codes/models/discriminator_vgg_arch.py
index ebae1284..f6fa50a4 100644
--- a/codes/models/discriminator_vgg_arch.py
+++ b/codes/models/discriminator_vgg_arch.py
@@ -169,6 +169,97 @@ class Discriminator_VGG_128_GN(nn.Module):
 
 @register_model
 def register_discriminator_vgg_128(opt_net, opt):
-    return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'],
+    return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
+                                    input_img_factor=opt_net['image_size'] / 128,
                                     extra_conv=opt_get(opt_net, ['extra_conv'], False),
                                     do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False))
+
+
+class DiscriminatorVGG448GN(nn.Module):
+    # Fixed 448x448-input discriminator: a new stride-2 head downsamples to 224x224, then feeds the standard VGG-GN trunk.
+    def __init__(self, in_nc, nf, do_checkpointing=False):
+        super().__init__()
+        self.do_checkpointing = do_checkpointing
+
+        # 448x448
+        self.convn1_0 = nn.Conv2d(in_nc, nf // 2, 3, 1, 1, bias=True)
+        self.convn1_1 = nn.Conv2d(nf // 2, nf // 2, 4, 2, 1, bias=False)
+        self.bnn1_1 = nn.GroupNorm(8, nf // 2, affine=True)
+
+        # 224x224 (new head)
+        self.conv0_0_new = nn.Conv2d(nf // 2, nf, 3, 1, 1, bias=False)
+        self.bn0_0 = nn.GroupNorm(8, nf, affine=True)
+        # 224x224 (old head)
+        self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)  # Unused.
+        self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
+        self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
+        # 112x112
+        self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
+        self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
+        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
+        self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
+        # 56x56
+        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
+        self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
+        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
+        self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
+        # 28x28
+        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
+        self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
+        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
+        self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
+        # 14x14
+        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
+        self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
+        self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
+        self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
+
+        # out: 7x7
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        final_nf = nf * 8
+        self.linear1 = nn.Linear(int(final_nf * 7 * 7), 100)
+        self.linear2 = nn.Linear(100, 1)
+
+        # Assign all new heads to the new param group.
+        for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]:
+            for p in m.parameters():
+                p.PARAM_GROUP = 'new_head'
+
+    def compute_body(self, x):
+        fea = self.lrelu(self.convn1_0(x))
+        fea = self.lrelu(self.bnn1_1(self.convn1_1(fea)))
+
+        fea = self.lrelu(self.bn0_0(self.conv0_0_new(fea)))
+        # fea = self.lrelu(self.conv0_0(x))  <- replaced
+        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
+
+        fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
+        fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
+
+        fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
+        fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
+
+        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
+        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
+
+        fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
+        fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
+        return fea
+
+    def forward(self, x):
+        if self.do_checkpointing:
+            fea = checkpoint(self.compute_body, x)
+        else:
+            fea = self.compute_body(x)
+        fea = fea.contiguous().view(fea.size(0), -1)
+        fea = self.lrelu(self.linear1(fea))
+        out = self.linear2(fea)
+        return out
+
+
+@register_model
+def register_discriminator_vgg_448(opt_net, opt):
+    return DiscriminatorVGG448GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
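Note on the new 448x448 discriminator: the added head halves the resolution once (448 -> 224) and the inherited trunk halves it five more times, ending at 7x7, so `linear1` sees `nf * 8 * 7 * 7` features. A minimal sanity-check sketch, assuming `codes/` is on the import path (the import line and the values below are illustrative, not part of this change):

```python
# Hypothetical smoke test for DiscriminatorVGG448GN; assumes codes/ is importable.
import torch
from models.discriminator_vgg_arch import DiscriminatorVGG448GN

disc = DiscriminatorVGG448GN(in_nc=3, nf=64)
x = torch.randn(2, 3, 448, 448)   # batch of two RGB crops at the expected resolution
out = disc(x)
assert out.shape == (2, 1)        # one realness logit per image

# The new-head parameters carry the PARAM_GROUP annotation consumed by steps.py.
new_head = [n for n, p in disc.named_parameters()
            if getattr(p, 'PARAM_GROUP', None) == 'new_head']
print(len(new_head), 'parameters in the new_head group')
```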
diff --git a/codes/train.py b/codes/train.py
index 5343abe0..5a819376 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -89,6 +89,7 @@ class Trainer:
             seed = random.randint(1, 10000)
         if self.rank <= 0:
             self.logger.info('Random seed: {}'.format(seed))
+        seed += self.rank  # Different multiprocessing instances should behave differently.
         util.set_random_seed(seed)
 
         torch.backends.cudnn.benchmark = True
@@ -293,7 +294,7 @@
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_bigboi_512.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 1ce50c73..2737142c 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -325,7 +325,8 @@ class ExtensibleTrainer(BaseModel):
 
         # Log learning rate (from first param group) too.
         for o in self.optimizers:
-            log['learning_rate_%s' % (o._config['network'],)] = o.param_groups[0]['lr']
+            for pgi, pg in enumerate(o.param_groups):
+                log['learning_rate_%s_%i' % (o._config['network'], pgi)] = pg['lr']
         return log
 
     def get_current_visuals(self, need_GT=True):
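With param groups in play, the ExtensibleTrainer change above emits one learning-rate entry per group instead of only `param_groups[0]`. A toy illustration of the resulting log keys (the `netD` name and LR values are made up for the example):

```python
# Toy illustration of per-group LR logging; 'netD' is a placeholder network name.
import torch
import torch.nn as nn

head, trunk = nn.Linear(8, 8), nn.Linear(8, 1)
opt = torch.optim.Adam([
    {'params': head.parameters(), 'lr': 1e-3},   # e.g. the 'new_head' group
    {'params': trunk.parameters(), 'lr': 1e-4},  # the 'default' group
])

log = {}
for pgi, pg in enumerate(opt.param_groups):
    log['learning_rate_%s_%i' % ('netD', pgi)] = pg['lr']
print(log)  # {'learning_rate_netD_0': 0.001, 'learning_rate_netD_1': 0.0001}
```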
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 0e491d3d..fb9d7033 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -64,7 +64,13 @@ class ConfigurableStep(Module):
             training = [training]
         self.optimizers = []
         for net_name, net, opt_config in zip(training, nets, opt_configs):
-            optim_params = []
+            # Configs can organize parameters by group and specify different learning rates for each group. This only
+            # works if the model specifically annotates which parameters belong in which group using PARAM_GROUP.
+            optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
+            if 'param_groups' in opt_config.keys():
+                for k, pg in opt_config['param_groups'].items():
+                    optim_params[k] = {'params': [], 'lr': pg['lr']}
+
             for k, v in net.named_parameters():  # can optimize for a part of the model
                 # Make some inference about these parameters, which can be used by some optimizers to treat certain
                 # parameters differently. For example, it is considered good practice to not do weight decay on
@@ -76,14 +82,23 @@
                     v.is_weight = True
                 if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
                     v.is_bn = True
+                # Some models can specify some parameters to be in different groups.
+                param_group = "default"
+                if hasattr(v, 'PARAM_GROUP'):
+                    if v.PARAM_GROUP in optim_params.keys():
+                        param_group = v.PARAM_GROUP
+                    else:
+                        logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
+                                       f'The same LR will be used for all parameters.')
+
                 if v.requires_grad:
-                    optim_params.append(v)
+                    optim_params[param_group]['params'].append(v)
                 else:
                     if self.env['rank'] <= 0:
                         logger.warning('Params [{:s}] will not optimize.'.format(k))
 
             if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
-                opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
+                opt = torch.optim.Adam(list(optim_params.values()),
                                        weight_decay=opt_config['weight_decay'],
                                        betas=(opt_config['beta1'], opt_config['beta2']))
             elif self.step_opt['optimizer'] == 'lars':
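Taken together, the model annotates parameters with `PARAM_GROUP` and the step's optimizer config supplies a matching `param_groups` entry; everything else falls into `default`. A condensed, self-contained sketch of that flow, with `opt_config` standing in for the step's YAML options (the layers and LR values are invented for illustration, and the unknown-group warning is omitted):

```python
# Condensed sketch of the new grouping logic; opt_config stands in for the step's YAML options.
import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
for p in net[0].parameters():
    p.PARAM_GROUP = 'new_head'            # model-side annotation, as in DiscriminatorVGG448GN

opt_config = {'lr': 1e-4, 'weight_decay': 0, 'beta1': 0.9, 'beta2': 0.99,
              'param_groups': {'new_head': {'lr': 1e-3}}}   # illustrative values

# Mirror of ConfigurableStep: one bucket per configured group, 'default' for everything else.
optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
for k, pg in opt_config.get('param_groups', {}).items():
    optim_params[k] = {'params': [], 'lr': pg['lr']}

for name, v in net.named_parameters():
    group = v.PARAM_GROUP if getattr(v, 'PARAM_GROUP', None) in optim_params else 'default'
    optim_params[group]['params'].append(v)

opt = torch.optim.Adam(list(optim_params.values()),
                       weight_decay=opt_config['weight_decay'],
                       betas=(opt_config['beta1'], opt_config['beta2']))
print([pg['lr'] for pg in opt.param_groups])  # [0.0001, 0.001]
```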