misc adjustments for stylegan

This commit is contained in:
James Betker 2021-04-21 18:14:17 -06:00
parent b687ef4cd0
commit 17555e7d07
6 changed files with 24 additions and 10 deletions

View File

@ -558,7 +558,12 @@ class Generator(nn.Module):
randomize_noise=True, randomize_noise=True,
): ):
if not input_is_latent: if not input_is_latent:
styles = [self.style(s) for s in styles] if self.training:
# In training mode, multiple style vectors are fed to the generator.
styles = [self.style(s) for s in styles]
else:
# In eval mode, only a single style is fed.
styles = [self.style(styles)]
if noise is None: if noise is None:
if randomize_noise: if randomize_noise:

View File

@ -57,8 +57,9 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffl
'name': 'amalgam', 'name': 'amalgam',
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_tiled_filtered_flattened'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_tiled_filtered_flattened'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
'weights': [1], 'weights': [1],
'target_size': target_size, 'target_size': target_size,
'force_multiple': 32, 'force_multiple': 32,
@ -116,6 +117,7 @@ def produce_latent_dict(model):
latents = [] latents = []
for batch in tqdm(dataloader): for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda') hq = batch['hq'].to('cuda')
hq = F.interpolate(F.interpolate(hq, size=(16,16), mode='bilinear'), size=(224,244))
model(hq) model(hq)
l = layer_hooked_value.cpu().split(1, dim=0) l = layer_hooked_value.cpu().split(1, dim=0)
latents.extend(l) latents.extend(l)
@ -202,7 +204,7 @@ if __name__ == '__main__':
register_hook(model, 'avgpool') register_hook(model, 'avgpool')
with torch.no_grad(): with torch.no_grad():
find_similar_latents(model, structural_euc_dist) #find_similar_latents(model, structural_euc_dist)
#produce_latent_dict(model) produce_latent_dict(model)
#build_kmeans() #build_kmeans()
#use_kmeans() #use_kmeans()

View File

@ -13,11 +13,14 @@ from tqdm import tqdm
import torch import torch
def forward_pass(model, output_dir, alteration_suffix=''): def forward_pass(model, output_dir, opt):
alteration_suffix = util.opt_get(opt, ['name'], '')
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
model.feed_data(data, 0, need_GT=need_GT) model.feed_data(data, 0, need_GT=need_GT)
model.test() model.test()
visuals = model.get_current_visuals(need_GT)['rlt'].cpu() visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
visuals = (visuals - denorm_range[0]) / (denorm_range[1]-denorm_range[0])
fea_loss = 0 fea_loss = 0
psnr_loss = 0 psnr_loss = 0
for i in range(visuals.shape[0]): for i in range(visuals.shape[0]):
@ -48,7 +51,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
want_metrics = False want_metrics = False
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_mi1.yml') parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_cats_stylegan2_rosinality.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt utils.util.loaded_options = opt
@ -90,7 +93,7 @@ if __name__ == "__main__":
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
need_GT = need_GT and want_metrics need_GT = need_GT and want_metrics
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name']) fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt)
fea_loss += fea_loss fea_loss += fea_loss
psnr_loss += psnr_loss psnr_loss += psnr_loss

View File

@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lightweight_gan_pna.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -259,8 +259,8 @@ class ExtensibleTrainer(BaseModel):
# Record visual outputs for usage in debugging and testing. # Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0: if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
denorm = opt_get(self.opt, ['logger', 'denormalize'], False) denorm = 'image_normalization_range' in self.opt.keys()
denorm_range = opt_get(self.opt, ['logger', 'denormalize_range'], None) denorm_range = opt_get(self.opt, ['image_normalization_range'], None)
if denorm_range: if denorm_range:
denorm_range = tuple(denorm_range) denorm_range = tuple(denorm_range)
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg") sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")

View File

@ -18,11 +18,13 @@ class StyleTransferEvaluator(evaluator.Evaluator):
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0 self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise') self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise')
self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512) # Not needed if using 'imgnoise' input. self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512) # Not needed if using 'imgnoise' input.
self.image_norm_range = tuple(opt_get(env['opt'], ['image_normalization_range'], [0,1]))
def perform_eval(self): def perform_eval(self):
fid_fake_path = osp.join(self.env['base_path'], "../", "fid", str(self.env["step"])) fid_fake_path = osp.join(self.env['base_path'], "../", "fid", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True) os.makedirs(fid_fake_path, exist_ok=True)
counter = 0 counter = 0
self.model.eval()
for i in range(self.batches_per_eval): for i in range(self.batches_per_eval):
if self.noise_type == 'imgnoise': if self.noise_type == 'imgnoise':
batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device']) batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
@ -32,9 +34,11 @@ class StyleTransferEvaluator(evaluator.Evaluator):
if not isinstance(gen, list) and not isinstance(gen, tuple): if not isinstance(gen, list) and not isinstance(gen, tuple):
gen = [gen] gen = [gen]
gen = gen[self.gen_output_index] gen = gen[self.gen_output_index]
gen = (gen - self.image_norm_range[0]) / (self.image_norm_range[1]-self.image_norm_range[0])
for b in range(self.batch_sz): for b in range(self.batch_sz):
torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter))) torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
counter += 1 counter += 1
self.model.train()
print("Got all images, computing fid") print("Got all images, computing fid")
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True, return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,