Modifications to allow partially trained stylegan discriminators to be used

This commit is contained in:
James Betker 2021-01-03 16:37:18 -07:00
parent 5e7ade0114
commit 4d8064c32c
6 changed files with 58 additions and 11 deletions

View File

@ -317,6 +317,9 @@ class StyleGan2Augmentor(nn.Module):
return self.D(images) return self.D(images)
def network_loaded(self):
self.D.network_loaded()
# stylegan2 classes # stylegan2 classes
@ -738,6 +741,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
class DiscriminatorBlock(nn.Module): class DiscriminatorBlock(nn.Module):
def __init__(self, input_channels, filters, downsample=True): def __init__(self, input_channels, filters, downsample=True):
super().__init__() super().__init__()
self.filters = filters
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1)) self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
self.net = nn.Sequential( self.net = nn.Sequential(
@ -763,7 +767,7 @@ class DiscriminatorBlock(nn.Module):
class StyleGan2Discriminator(nn.Module): class StyleGan2Discriminator(nn.Module):
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[], def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False): transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
super().__init__() super().__init__()
num_layers = int(log2(image_size) - 1) num_layers = int(log2(image_size) - 1)
@ -805,7 +809,11 @@ class StyleGan2Discriminator(nn.Module):
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1) self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.flatten = Flatten() self.flatten = Flatten()
self.to_logit = nn.Linear(latent_dim, 1) if mlp:
self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
nn.Linear(100, 1))
else:
self.to_logit = nn.Linear(latent_dim, 1)
self._init_weights() self._init_weights()
@ -840,6 +848,38 @@ class StyleGan2Discriminator(nn.Module):
if type(m) in {nn.Conv2d, nn.Linear}: if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
# Configures the network as partially pre-trained. This means:
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
# 2) The haed (linear layers) will also have their weights re-initialized
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
# These settings will be applied after the weights have been loaded (network_loaded())
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
self.bypass_blocks = bypass_blocks
self.num_blocks = num_blocks
self.frozen_until_step = frozen_until_step
# Called after the network weights are loaded.
def network_loaded(self):
if not hasattr(self, 'frozen_until_step'):
return
if self.bypass_blocks > 0:
self.blocks = self.blocks[self.bypass_blocks:]
self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device)
reset_blocks = [self.to_logit]
for i in range(self.num_blocks):
reset_blocks.append(self.blocks[i])
for bl in reset_blocks:
for m in bl.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for p in m.parameters(recurse=True):
p._NEW_BLOCK = True
for p in self.parameters():
if not hasattr(p, '_NEW_BLOCK'):
p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step
class StyleGan2DivergenceLoss(L.ConfigurableLoss): class StyleGan2DivergenceLoss(L.ConfigurableLoss):
def __init__(self, opt, env): def __init__(self, opt, env):
@ -916,5 +956,8 @@ def register_stylegan2_discriminator(opt_net, opt):
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn, disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False), do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
quantize=opt_get(opt_net, ['quantize'], False)) quantize=opt_get(opt_net, ['quantize'], False),
mlp=opt_get(opt_net, ['mlp_head'], True))
if 'use_partial_pretrained' in opt_net.keys():
disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])

View File

@ -3,8 +3,8 @@ import torch
from models.spinenet_arch import SpineNet from models.spinenet_arch import SpineNet
if __name__ == '__main__': if __name__ == '__main__':
pretrained_path = '../../experiments/train_sbyol_512unsupervised_restart/models/48000_generator.pth' pretrained_path = '../../experiments/byol_discriminator.pth'
output_path = '../../experiments/spinenet49_imgset_sbyol.pth' output_path = '../../experiments/byol_discriminator_extracted.pth'
wrap_key = 'online_encoder.net.' wrap_key = 'online_encoder.net.'
sd = torch.load(pretrained_path) sd = torch.load(pretrained_path)
@ -13,8 +13,8 @@ if __name__ == '__main__':
if wrap_key in k: if wrap_key in k:
sdo[k.replace(wrap_key, '')] = v sdo[k.replace(wrap_key, '')] = v
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda') #model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
model.load_state_dict(sdo, strict=True) #model.load_state_dict(sdo, strict=True)
print("Validation succeeded, dumping state dict to output path.") print("Validation succeeded, dumping state dict to output path.")
torch.save(sdo, output_path) torch.save(sdo, output_path)

View File

@ -106,7 +106,7 @@ def get_latent_for_img(model, img):
def find_similar_latents(model, compare_fn=structural_euc_dist): def find_similar_latents(model, compare_fn=structural_euc_dist):
global layer_hooked_value global layer_hooked_value
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\poon.jpg' img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg' #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
output_path = '../../results/byol_resnet_similars' output_path = '../../results/byol_resnet_similars'
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
@ -141,7 +141,7 @@ def find_similar_latents(model, compare_fn=structural_euc_dist):
if __name__ == '__main__': if __name__ == '__main__':
pretrained_path = '../../experiments/resnet_byol_diffframe_85k.pth' pretrained_path = '../../experiments/resnet_byol_diffframe_115k.pth'
model = resnet50(pretrained=False).to('cuda') model = resnet50(pretrained=False).to('cuda')
sd = torch.load(pretrained_path) sd = torch.load(pretrained_path)
resnet_sd = {} resnet_sd = {}

View File

@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_discriminator_diffimage.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_stylesr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -209,7 +209,8 @@ class ExtensibleTrainer(BaseModel):
if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']: if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
net_enabled = False net_enabled = False
for p in net.parameters(): for p in net.parameters():
if p.dtype != torch.int64 and p.dtype != torch.bool and not hasattr(p, "DO_NOT_TRAIN"): do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and step < p.DO_NOT_TRAIN_UNTIL)
if p.dtype != torch.int64 and p.dtype != torch.bool and not do_not_train_flag:
p.requires_grad = net_enabled p.requires_grad = net_enabled
else: else:
p.requires_grad = False p.requires_grad = False
@ -357,6 +358,8 @@ class ExtensibleTrainer(BaseModel):
if self.rank <= 0: if self.rank <= 0:
logger.info('Loading model for [%s]' % (load_path,)) logger.info('Loading model for [%s]' % (load_path,))
self.load_network(load_path, net, self.opt['path']['strict_load']) self.load_network(load_path, net, self.opt['path']['strict_load'])
if hasattr(net.module, 'network_loaded'):
net.module.network_loaded()
def save(self, iter_step): def save(self, iter_step):
for name, net in self.networks.items(): for name, net in self.networks.items():

View File

@ -58,6 +58,7 @@ class BaseModel():
def update_learning_rate(self, cur_iter, warmup_iter=-1): def update_learning_rate(self, cur_iter, warmup_iter=-1):
for scheduler in self.schedulers: for scheduler in self.schedulers:
scheduler.last_epoch = cur_iter
scheduler.step() scheduler.step()
# set up warm-up learning rate # set up warm-up learning rate
if cur_iter < warmup_iter: if cur_iter < warmup_iter: