forked from mrq/DL-Art-School
Allow optimizers to train separate param groups, add higher dimensional VGG discriminator
Did this to support training 512x512px networks off of a pretrained 256x256 network.
This commit is contained in:
parent
193cdc6636
commit
bdbab65082
|
@ -169,6 +169,97 @@ class Discriminator_VGG_128_GN(nn.Module):
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_discriminator_vgg_128(opt_net, opt):
|
def register_discriminator_vgg_128(opt_net, opt):
|
||||||
return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'],
|
return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
|
||||||
|
input_img_factor=opt_net['image_size'] / 128,
|
||||||
extra_conv=opt_get(opt_net, ['extra_conv'], False),
|
extra_conv=opt_get(opt_net, ['extra_conv'], False),
|
||||||
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False))
|
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False))
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorVGG448GN(nn.Module):
    """VGG-style discriminator with GroupNorm, built for 448x448 inputs.

    A new stem (``convn1_0``, ``convn1_1``, ``bnn1_1``, ``conv0_0_new``,
    ``bn0_0``) halves the input resolution and feeds the stack that starts at
    ``conv0_1``. Every parameter of that new stem is tagged with
    ``PARAM_GROUP = 'new_head'`` so an optimizer can place it in its own
    parameter group (for example, with a different learning rate).

    NOTE(review): the ``conv0_1`` .. ``linear2`` layer names presumably mirror
    an existing lower-resolution discriminator so its pretrained weights can
    be loaded into this one -- confirm against the checkpoint being used.

    Args:
        in_nc: Number of channels in the input image.
        nf: Base number of feature channels. ``nf // 2`` must be divisible by
            8 because every ``GroupNorm`` here uses 8 groups.
        do_checkpointing: If True, the convolutional body is run through
            ``checkpoint`` instead of being called directly.
    """

    def __init__(self, in_nc, nf, do_checkpointing=False):
        super().__init__()
        self.do_checkpointing = do_checkpointing

        # 448x448
        self.convn1_0 = nn.Conv2d(in_nc, nf // 2, 3, 1, 1, bias=True)
        self.convn1_1 = nn.Conv2d(nf // 2, nf // 2, 4, 2, 1, bias=False)
        self.bnn1_1 = nn.GroupNorm(8, nf // 2, affine=True)

        # 224x224 (new head)
        self.conv0_0_new = nn.Conv2d(nf // 2, nf, 3, 1, 1, bias=False)
        self.bn0_0 = nn.GroupNorm(8, nf, affine=True)
        # 224x224 (old head)
        # conv0_0 is never called in compute_body(); it is still registered so
        # the module's state_dict keys stay unchanged.
        # NOTE(review): presumably needed to load older checkpoints -- confirm.
        self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)  # Unused.
        self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
        self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
        # 112x112
        self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
        # 56x56
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
        # 28x28
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
        # 14x14
        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
        self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)

        # out: 7x7

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # The classifier consumes the flattened 7x7 feature map.
        final_nf = nf * 8
        self.linear1 = nn.Linear(final_nf * 7 * 7, 100)
        self.linear2 = nn.Linear(100, 1)

        # Assign all new heads to the new param group.
        for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]:
            for p in m.parameters():
                p.PARAM_GROUP = 'new_head'

    def compute_body(self, x):
        """Run the convolutional stack on ``x`` and return the feature map."""
        # New stem: 448x448 -> 224x224.
        fea = self.lrelu(self.convn1_0(x))
        fea = self.lrelu(self.bnn1_1(self.convn1_1(fea)))

        # conv0_0_new takes the place of conv0_0, which consumed the raw image.
        fea = self.lrelu(self.bn0_0(self.conv0_0_new(fea)))
        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))

        fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
        fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))

        fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
        fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))

        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))

        fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
        fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
        return fea

    def forward(self, x):
        """Return one unnormalized score per batch element, shape ``(batch, 1)``.

        Raises:
            RuntimeError: If the flattened feature map does not match the
                input size of ``linear1``, i.e. the image has the wrong
                spatial size for this network.
        """
        if self.do_checkpointing:
            fea = checkpoint(self.compute_body, x)
        else:
            fea = self.compute_body(x)
        fea = fea.contiguous().view(fea.size(0), -1)
        # Fail with an actionable message rather than the opaque matmul shape
        # error linear1 would otherwise raise for a wrong-sized input.
        if fea.size(1) != self.linear1.in_features:
            raise RuntimeError(
                f'DiscriminatorVGG448GN expected 448x448 inputs but got an input of spatial size '
                f'{tuple(x.shape[-2:])}: flattened features {fea.size(1)} != {self.linear1.in_features}.')
        fea = self.lrelu(self.linear1(fea))
        out = self.linear2(fea)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def register_discriminator_vgg_448(opt_net, opt):
    """Build a DiscriminatorVGG448GN from a network options dict.

    Reads ``in_nc`` and ``nf`` from ``opt_net``. The optional
    ``do_checkpointing`` key (default False) is forwarded as well, so
    checkpointing can be enabled from configuration in the same way as for
    ``register_discriminator_vgg_128``. ``opt`` is not used here.
    """
    return DiscriminatorVGG448GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                 do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False))
|
||||||
|
|
|
@ -89,6 +89,7 @@ class Trainer:
|
||||||
seed = random.randint(1, 10000)
|
seed = random.randint(1, 10000)
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
self.logger.info('Random seed: {}'.format(seed))
|
self.logger.info('Random seed: {}'.format(seed))
|
||||||
|
seed += self.rank # Different multiprocessing instances should behave differently.
|
||||||
util.set_random_seed(seed)
|
util.set_random_seed(seed)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
@ -293,7 +294,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_bigboi_512.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -325,7 +325,8 @@ class ExtensibleTrainer(BaseModel):
|
||||||
|
|
||||||
# Log learning rate (from first param group) too.
|
# Log the learning rate of every param group of each optimizer too.
|
||||||
for o in self.optimizers:
|
for o in self.optimizers:
|
||||||
log['learning_rate_%s' % (o._config['network'],)] = o.param_groups[0]['lr']
|
for pgi, pg in enumerate(o.param_groups):
|
||||||
|
log['learning_rate_%s_%i' % (o._config['network'], pgi)] = pg['lr']
|
||||||
return log
|
return log
|
||||||
|
|
||||||
def get_current_visuals(self, need_GT=True):
|
def get_current_visuals(self, need_GT=True):
|
||||||
|
|
|
@ -64,7 +64,13 @@ class ConfigurableStep(Module):
|
||||||
training = [training]
|
training = [training]
|
||||||
self.optimizers = []
|
self.optimizers = []
|
||||||
for net_name, net, opt_config in zip(training, nets, opt_configs):
|
for net_name, net, opt_config in zip(training, nets, opt_configs):
|
||||||
optim_params = []
|
# Configs can organize parameters by-group and specify different learning rates for each group. This only
|
||||||
|
# works if the model specifically annotates which parameters belong in which group using PARAM_GROUP.
|
||||||
|
optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
|
||||||
|
if 'param_groups' in opt_config.keys():
|
||||||
|
for k, pg in opt_config['param_groups'].items():
|
||||||
|
optim_params[k] = {'params': [], 'lr': pg['lr']}
|
||||||
|
|
||||||
for k, v in net.named_parameters(): # can optimize for a part of the model
|
for k, v in net.named_parameters(): # can optimize for a part of the model
|
||||||
# Make some inference about these parameters, which can be used by some optimizers to treat certain
|
# Make some inference about these parameters, which can be used by some optimizers to treat certain
|
||||||
# parameters differently. For example, it is considered good practice to not do weight decay on
|
# parameters differently. For example, it is considered good practice to not do weight decay on
|
||||||
|
@ -76,14 +82,23 @@ class ConfigurableStep(Module):
|
||||||
v.is_weight = True
|
v.is_weight = True
|
||||||
if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
|
if ".bn" in k or '.batchnorm' in k or '.bnorm' in k:
|
||||||
v.is_bn = True
|
v.is_bn = True
|
||||||
|
# Some models can specify some parameters to be in different groups.
|
||||||
|
param_group = "default"
|
||||||
|
if hasattr(v, 'PARAM_GROUP'):
|
||||||
|
if v.PARAM_GROUP in optim_params.keys():
|
||||||
|
param_group = v.PARAM_GROUP
|
||||||
|
else:
|
||||||
|
logger.warning(f'Model specifies a custom param group {v.PARAM_GROUP} which is not configured. '
|
||||||
|
f'The same LR will be used for all parameters.')
|
||||||
|
|
||||||
if v.requires_grad:
|
if v.requires_grad:
|
||||||
optim_params.append(v)
|
optim_params[param_group]['params'].append(v)
|
||||||
else:
|
else:
|
||||||
if self.env['rank'] <= 0:
|
if self.env['rank'] <= 0:
|
||||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||||
|
|
||||||
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
|
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
|
||||||
opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
|
opt = torch.optim.Adam(list(optim_params.values()),
|
||||||
weight_decay=opt_config['weight_decay'],
|
weight_decay=opt_config['weight_decay'],
|
||||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||||
elif self.step_opt['optimizer'] == 'lars':
|
elif self.step_opt['optimizer'] == 'lars':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user