Add option to work with nonrandom latents
commit 080ad61be4
parent 566b99ca75
@@ -544,23 +544,43 @@ class StyleGan2GeneratorWithLatent(nn.Module):
     def noise(self, n, latent_dim, device):
         return torch.randn(n, latent_dim).cuda(device)
 
+    def noise_list(self, n, layers, latent_dim, device):
+        return [(self.noise(n, latent_dim, device), layers)]
+
+    def mixed_list(self, n, layers, latent_dim, device):
+        tt = int(torch.rand(()).numpy() * layers)
+        return self.noise_list(n, tt, latent_dim, device) + self.noise_list(n, layers - tt, latent_dim, device)
+
+    def latent_to_w(self, style_vectorizer, latent_descr):
+        return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr]
+
+    def styles_def_to_tensor(self, styles_def):
+        return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)
+
     # To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
     # b,f,h,w.
     def forward(self, x):
         b, f, h, w = x.shape
 
-        style = self.noise(b*2, self.gen.latent_dim, x.device)
-        w = self.vectorizer(style)
-
-        # Randomly distribute styles across layers
-        w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
-        for j in range(b):
-            cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
-            if cutoff == self.gen.num_layers or random() > self.mixed_prob:
-                w_styles[j] = w_styles[j*2]
-            else:
-                w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
-                w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
-        w_styles = w_styles[:b]
+        full_random_latents = False
+        if full_random_latents:
+            style = self.noise(b*2, self.gen.latent_dim, x.device)
+            w = self.vectorizer(style)
+            # Randomly distribute styles across layers
+            w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
+            for j in range(b):
+                cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
+                if cutoff == self.gen.num_layers or random() > self.mixed_prob:
+                    w_styles[j] = w_styles[j*2]
+                else:
+                    w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
+                    w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
+            w_styles = w_styles[:b]
+        else:
+            get_latents_fn = self.mixed_list if random() < self.mixed_prob else self.noise_list
+            style = get_latents_fn(b, self.gen.num_layers, self.gen.latent_dim, device=x.device)
+            w_space = self.latent_to_w(self.vectorizer, style)
+            w_styles = self.styles_def_to_tensor(w_space)
 
         # The underlying model expects the noise as b,h,w,1. Make it so.
         return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles
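For context, below is a minimal standalone sketch (not the repository's code) of what the new list-based latent path computes: a latent descriptor is a list of (latent batch, layer count) pairs, which the mapping network turns into W vectors that are then expanded across the generator layers they own. The latent_dim, num_layers and batch values, and the identity stand-in for self.vectorizer, are illustrative assumptions only.

import torch

latent_dim, num_layers, batch = 512, 7, 4

def noise(n, latent_dim):
    return torch.randn(n, latent_dim)

def noise_list(n, layers, latent_dim):
    # A single latent shared by `layers` consecutive generator layers.
    return [(noise(n, latent_dim), layers)]

def mixed_list(n, layers, latent_dim):
    # Style mixing: split the layers at a random cutoff, one latent per segment.
    tt = int(torch.rand(()).item() * layers)
    return noise_list(n, tt, latent_dim) + noise_list(n, layers - tt, latent_dim)

def latent_to_w(vectorizer, latent_descr):
    return [(vectorizer(z), n_layers) for z, n_layers in latent_descr]

def styles_def_to_tensor(styles_def):
    # Expand each W vector over the layers it owns, then concatenate along the layer axis.
    return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)

style = mixed_list(batch, num_layers, latent_dim)
w_space = latent_to_w(lambda z: z, style)   # identity stands in for the mapping network
w_styles = styles_def_to_tensor(w_space)
print(w_styles.shape)                        # torch.Size([4, 7, 512])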
@@ -628,7 +648,7 @@ class StyleGan2Discriminator(nn.Module):
         quantize_loss = torch.zeros(1).to(x)
 
         for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
-            x = checkpoint(block, x)
+            x = block(x)
 
             if exists(attn_block):
                 x = attn_block(x)
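The discriminator change swaps checkpoint(block, x) for a plain block(x) call. Assuming checkpoint here refers to torch.utils.checkpoint.checkpoint (an assumption based on the call signature), the difference is a memory-for-compute trade-off, sketched below with a toy block:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU())
x = torch.randn(2, 3, 32, 32, requires_grad=True)

# Checkpointed: intermediate activations inside `block` are discarded and
# recomputed during backward, lowering peak memory at the cost of extra compute.
y_ckpt = checkpoint(block, x)

# Direct call (what this commit switches to): activations are kept, so the
# backward pass is faster but peak memory is higher.
y_direct = block(x)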
@@ -274,7 +274,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
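Only the default -opt path changes in this hunk, so the previous configuration can still be selected explicitly on the command line; assuming the training script is invoked directly (the script name here is illustrative), something like:

python train.py -opt ../options/train_stylegan2.yml --launcher none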