Add option to work with nonrandom latents

This commit is contained in:
James Betker 2020-11-12 21:23:50 -07:00
parent 566b99ca75
commit 080ad61be4
2 changed files with 34 additions and 14 deletions

View File

@ -544,23 +544,43 @@ class StyleGan2GeneratorWithLatent(nn.Module):
def noise(self, n, latent_dim, device): def noise(self, n, latent_dim, device):
return torch.randn(n, latent_dim).cuda(device) return torch.randn(n, latent_dim).cuda(device)
def noise_list(self, n, layers, latent_dim, device):
return [(self.noise(n, latent_dim, device), layers)]
def mixed_list(self, n, layers, latent_dim, device):
tt = int(torch.rand(()).numpy() * layers)
return self.noise_list(n, tt, latent_dim, device) + self.noise_list(n, layers - tt, latent_dim, device)
def latent_to_w(self, style_vectorizer, latent_descr):
return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr]
def styles_def_to_tensor(self, styles_def):
return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)
# To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format: # To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
# b,f,h,w. # b,f,h,w.
def forward(self, x): def forward(self, x):
b, f, h, w = x.shape b, f, h, w = x.shape
style = self.noise(b*2, self.gen.latent_dim, x.device)
w = self.vectorizer(style)
# Randomly distribute styles across layers full_random_latents = False
w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone() if full_random_latents:
for j in range(b): style = self.noise(b*2, self.gen.latent_dim, x.device)
cutoff = int(torch.rand(()).numpy() * self.gen.num_layers) w = self.vectorizer(style)
if cutoff == self.gen.num_layers or random() > self.mixed_prob: # Randomly distribute styles across layers
w_styles[j] = w_styles[j*2] w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
else: for j in range(b):
w_styles[j, :cutoff] = w_styles[j*2, :cutoff] cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:] if cutoff == self.gen.num_layers or random() > self.mixed_prob:
w_styles = w_styles[:b] w_styles[j] = w_styles[j*2]
else:
w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
w_styles = w_styles[:b]
else:
get_latents_fn = self.mixed_list if random() < self.mixed_prob else self.noise_list
style = get_latents_fn(b, self.gen.num_layers, self.gen.latent_dim, device=x.device)
w_space = self.latent_to_w(self.vectorizer, style)
w_styles = self.styles_def_to_tensor(w_space)
# The underlying model expects the noise as b,h,w,1. Make it so. # The underlying model expects the noise as b,h,w,1. Make it so.
return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles
@ -628,7 +648,7 @@ class StyleGan2Discriminator(nn.Module):
quantize_loss = torch.zeros(1).to(x) quantize_loss = torch.zeros(1).to(x)
for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks): for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
x = checkpoint(block, x) x = block(x)
if exists(attn_block): if exists(attn_block):
x = attn_block(x) x = attn_block(x)

View File

@ -274,7 +274,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()