forked from mrq/DL-Art-School
Undo lucidrains changes for new discriminator
This "new" code will live in the styledsr directory from now on.
This commit is contained in:
parent
40ec71da81
commit
2225fe6ac2
|
@ -317,9 +317,6 @@ class StyleGan2Augmentor(nn.Module):
|
||||||
|
|
||||||
return self.D(images)
|
return self.D(images)
|
||||||
|
|
||||||
def network_loaded(self):
|
|
||||||
self.D.network_loaded()
|
|
||||||
|
|
||||||
|
|
||||||
# stylegan2 classes
|
# stylegan2 classes
|
||||||
|
|
||||||
|
@ -741,7 +738,6 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
||||||
class DiscriminatorBlock(nn.Module):
|
class DiscriminatorBlock(nn.Module):
|
||||||
def __init__(self, input_channels, filters, downsample=True):
|
def __init__(self, input_channels, filters, downsample=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.filters = filters
|
|
||||||
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
|
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
|
@ -767,7 +763,7 @@ class DiscriminatorBlock(nn.Module):
|
||||||
|
|
||||||
class StyleGan2Discriminator(nn.Module):
|
class StyleGan2Discriminator(nn.Module):
|
||||||
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
|
||||||
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False, mlp=False):
|
transparent=False, fmap_max=512, input_filters=3, quantize=False, do_checkpointing=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
num_layers = int(log2(image_size) - 1)
|
num_layers = int(log2(image_size) - 1)
|
||||||
|
|
||||||
|
@ -809,12 +805,7 @@ class StyleGan2Discriminator(nn.Module):
|
||||||
|
|
||||||
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
|
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
|
||||||
self.flatten = Flatten()
|
self.flatten = Flatten()
|
||||||
if mlp:
|
self.to_logit = nn.Linear(latent_dim, 1)
|
||||||
self.to_logit = nn.Sequential(nn.Linear(latent_dim, 100),
|
|
||||||
leaky_relu(),
|
|
||||||
nn.Linear(100, 1))
|
|
||||||
else:
|
|
||||||
self.to_logit = nn.Linear(latent_dim, 1)
|
|
||||||
|
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
||||||
|
@ -849,38 +840,6 @@ class StyleGan2Discriminator(nn.Module):
|
||||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
if type(m) in {nn.Conv2d, nn.Linear}:
|
||||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||||
|
|
||||||
# Configures the network as partially pre-trained. This means:
|
|
||||||
# 1) The top (high-resolution) `num_blocks` will have their weights re-initialized.
|
|
||||||
# 2) The haed (linear layers) will also have their weights re-initialized
|
|
||||||
# 3) All intermediate blocks will be frozen until step `frozen_until_step`
|
|
||||||
# These settings will be applied after the weights have been loaded (network_loaded())
|
|
||||||
def configure_partial_training(self, bypass_blocks=0, num_blocks=2, frozen_until_step=0):
|
|
||||||
self.bypass_blocks = bypass_blocks
|
|
||||||
self.num_blocks = num_blocks
|
|
||||||
self.frozen_until_step = frozen_until_step
|
|
||||||
|
|
||||||
# Called after the network weights are loaded.
|
|
||||||
def network_loaded(self):
|
|
||||||
if not hasattr(self, 'frozen_until_step'):
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.bypass_blocks > 0:
|
|
||||||
self.blocks = self.blocks[self.bypass_blocks:]
|
|
||||||
self.blocks[0] = DiscriminatorBlock(3, self.blocks[0].filters, downsample=True).to(next(self.parameters()).device)
|
|
||||||
|
|
||||||
reset_blocks = [self.to_logit]
|
|
||||||
for i in range(self.num_blocks):
|
|
||||||
reset_blocks.append(self.blocks[i])
|
|
||||||
for bl in reset_blocks:
|
|
||||||
for m in bl.modules():
|
|
||||||
if type(m) in {nn.Conv2d, nn.Linear}:
|
|
||||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
|
||||||
for p in m.parameters(recurse=True):
|
|
||||||
p._NEW_BLOCK = True
|
|
||||||
for p in self.parameters():
|
|
||||||
if not hasattr(p, '_NEW_BLOCK'):
|
|
||||||
p.DO_NOT_TRAIN_UNTIL = self.frozen_until_step
|
|
||||||
|
|
||||||
|
|
||||||
class StyleGan2DivergenceLoss(L.ConfigurableLoss):
|
class StyleGan2DivergenceLoss(L.ConfigurableLoss):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
|
@ -957,8 +916,5 @@ def register_stylegan2_discriminator(opt_net, opt):
|
||||||
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
|
||||||
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
|
disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn,
|
||||||
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
|
do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False),
|
||||||
quantize=opt_get(opt_net, ['quantize'], False),
|
quantize=opt_get(opt_net, ['quantize'], False))
|
||||||
mlp=opt_get(opt_net, ['mlp_head'], True))
|
|
||||||
if 'use_partial_pretrained' in opt_net.keys():
|
|
||||||
disc.configure_partial_training(opt_net['bypass_blocks'], opt_net['partial_training_blocks'], opt_net['intermediate_blocks_frozen_until'])
|
|
||||||
return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user