diff --git a/codes/models/archs/stylegan2.py b/codes/models/archs/stylegan2.py
index ec0a2804..ed5bd38b 100644
--- a/codes/models/archs/stylegan2.py
+++ b/codes/models/archs/stylegan2.py
@@ -544,23 +544,43 @@ class StyleGan2GeneratorWithLatent(nn.Module):
     def noise(self, n, latent_dim, device):
         return torch.randn(n, latent_dim).cuda(device)

+    def noise_list(self, n, layers, latent_dim, device):
+        return [(self.noise(n, latent_dim, device), layers)]
+
+    def mixed_list(self, n, layers, latent_dim, device):
+        tt = int(torch.rand(()).numpy() * layers)
+        return self.noise_list(n, tt, latent_dim, device) + self.noise_list(n, layers - tt, latent_dim, device)
+
+    def latent_to_w(self, style_vectorizer, latent_descr):
+        return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr]
+
+    def styles_def_to_tensor(self, styles_def):
+        return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)
+
     # To use per the stylegan paper, input should be uniform noise. This gen takes it in as a normal "image" format:
     # b,f,h,w.
     def forward(self, x):
         b, f, h, w = x.shape

-        style = self.noise(b*2, self.gen.latent_dim, x.device)
-        w = self.vectorizer(style)
-        # Randomly distribute styles across layers
-        w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
-        for j in range(b):
-            cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
-            if cutoff == self.gen.num_layers or random() > self.mixed_prob:
-                w_styles[j] = w_styles[j*2]
-            else:
-                w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
-                w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
-        w_styles = w_styles[:b]
+        full_random_latents = False
+        if full_random_latents:
+            style = self.noise(b*2, self.gen.latent_dim, x.device)
+            w = self.vectorizer(style)
+            # Randomly distribute styles across layers
+            w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone()
+            for j in range(b):
+                cutoff = int(torch.rand(()).numpy() * self.gen.num_layers)
+                if cutoff == self.gen.num_layers or random() > self.mixed_prob:
+                    w_styles[j] = w_styles[j*2]
+                else:
+                    w_styles[j, :cutoff] = w_styles[j*2, :cutoff]
+                    w_styles[j, cutoff:] = w_styles[j*2+1, cutoff:]
+            w_styles = w_styles[:b]
+        else:
+            get_latents_fn = self.mixed_list if random() < self.mixed_prob else self.noise_list
+            style = get_latents_fn(b, self.gen.num_layers, self.gen.latent_dim, device=x.device)
+            w_space = self.latent_to_w(self.vectorizer, style)
+            w_styles = self.styles_def_to_tensor(w_space)

         # The underlying model expects the noise as b,h,w,1. Make it so.
         return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3)), w_styles
@@ -628,7 +648,7 @@ class StyleGan2Discriminator(nn.Module):
         quantize_loss = torch.zeros(1).to(x)

         for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks):
-            x = checkpoint(block, x)
+            x = block(x)

             if exists(attn_block):
                 x = attn_block(x)
diff --git a/codes/train2.py b/codes/train2.py
index ad4ddb1c..ec5211f6 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -274,7 +274,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
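
Note on the generator change: the new else branch replaces the per-sample cutoff loop with a list-based style-mixing path. Either one z is sampled for every layer, or, with probability mixed_prob, two z's are sampled and split across a random layer cutoff; each z is mapped through the style network and expanded into a (b, num_layers, latent_dim) tensor. The sketch below walks through that pipeline standalone, using the same helper names as the diff; the plain nn.Linear vectorizer and the hard-coded mixed_prob of 0.9 are stand-ins for illustration, not the repo's actual components.

import random
import torch
import torch.nn as nn

# Stand-ins for illustration only: the repo has its own style MLP, and
# mixed_prob comes from the model config.
batch, num_layers, latent_dim = 4, 6, 512
vectorizer = nn.Linear(latent_dim, latent_dim)
mixed_prob = 0.9

def noise_list(n, layers, dim):
    # One z vector that will cover `layers` consecutive generator layers.
    return [(torch.randn(n, dim), layers)]

def mixed_list(n, layers, dim):
    # Split the layer range at a random cutoff; each side gets its own z.
    tt = int(torch.rand(()).item() * layers)
    return noise_list(n, tt, dim) + noise_list(n, layers - tt, dim)

def latent_to_w(style_vectorizer, latent_descr):
    # Push each z through the style network, keeping its layer count.
    return [(style_vectorizer(z), n) for z, n in latent_descr]

def styles_def_to_tensor(styles_def):
    # Broadcast each w over its layer span and concatenate along dim 1.
    return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)

get_latents_fn = mixed_list if random.random() < mixed_prob else noise_list
style = get_latents_fn(batch, num_layers, latent_dim)
w_styles = styles_def_to_tensor(latent_to_w(vectorizer, style))
print(w_styles.shape)  # torch.Size([4, 6, 512])

The discriminator hunk likewise swaps checkpoint(block, x) for a plain block(x): gradient checkpointing recomputes activations during the backward pass to save memory, so dropping it spends more memory in exchange for faster steps, consistent with the switch to the train_stylegan2_faster.yml default config.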