misc adjustments for stylegan
This commit is contained in:
parent
b687ef4cd0
commit
17555e7d07
|
@ -558,7 +558,12 @@ class Generator(nn.Module):
|
|||
randomize_noise=True,
|
||||
):
|
||||
if not input_is_latent:
|
||||
styles = [self.style(s) for s in styles]
|
||||
if self.training:
|
||||
# In training mode, multiple style vectors are fed to the generator.
|
||||
styles = [self.style(s) for s in styles]
|
||||
else:
|
||||
# In eval mode, only a single style is fed.
|
||||
styles = [self.style(styles)]
|
||||
|
||||
if noise is None:
|
||||
if randomize_noise:
|
||||
|
|
|
@ -57,8 +57,9 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffl
|
|||
'name': 'amalgam',
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_tiled_filtered_flattened'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_tiled_filtered_flattened'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||
'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
'weights': [1],
|
||||
'target_size': target_size,
|
||||
'force_multiple': 32,
|
||||
|
@ -116,6 +117,7 @@ def produce_latent_dict(model):
|
|||
latents = []
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda')
|
||||
hq = F.interpolate(F.interpolate(hq, size=(16,16), mode='bilinear'), size=(224,244))
|
||||
model(hq)
|
||||
l = layer_hooked_value.cpu().split(1, dim=0)
|
||||
latents.extend(l)
|
||||
|
@ -202,7 +204,7 @@ if __name__ == '__main__':
|
|||
register_hook(model, 'avgpool')
|
||||
|
||||
with torch.no_grad():
|
||||
find_similar_latents(model, structural_euc_dist)
|
||||
#produce_latent_dict(model)
|
||||
#find_similar_latents(model, structural_euc_dist)
|
||||
produce_latent_dict(model)
|
||||
#build_kmeans()
|
||||
#use_kmeans()
|
||||
|
|
|
@ -13,11 +13,14 @@ from tqdm import tqdm
|
|||
import torch
|
||||
|
||||
|
||||
def forward_pass(model, output_dir, alteration_suffix=''):
|
||||
def forward_pass(model, output_dir, opt):
|
||||
alteration_suffix = util.opt_get(opt, ['name'], '')
|
||||
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
|
||||
model.feed_data(data, 0, need_GT=need_GT)
|
||||
model.test()
|
||||
|
||||
visuals = model.get_current_visuals(need_GT)['rlt'].cpu()
|
||||
visuals = (visuals - denorm_range[0]) / (denorm_range[1]-denorm_range[0])
|
||||
fea_loss = 0
|
||||
psnr_loss = 0
|
||||
for i in range(visuals.shape[0]):
|
||||
|
@ -48,7 +51,7 @@ if __name__ == "__main__":
|
|||
torch.backends.cudnn.benchmark = True
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_mi1.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_cats_stylegan2_rosinality.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
@ -90,7 +93,7 @@ if __name__ == "__main__":
|
|||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||
need_GT = need_GT and want_metrics
|
||||
|
||||
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt['name'])
|
||||
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt)
|
||||
fea_loss += fea_loss
|
||||
psnr_loss += psnr_loss
|
||||
|
||||
|
|
|
@ -295,7 +295,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lightweight_gan_pna.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -259,8 +259,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
# Record visual outputs for usage in debugging and testing.
|
||||
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
|
||||
denorm = opt_get(self.opt, ['logger', 'denormalize'], False)
|
||||
denorm_range = opt_get(self.opt, ['logger', 'denormalize_range'], None)
|
||||
denorm = 'image_normalization_range' in self.opt.keys()
|
||||
denorm_range = opt_get(self.opt, ['image_normalization_range'], None)
|
||||
if denorm_range:
|
||||
denorm_range = tuple(denorm_range)
|
||||
sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
|
||||
|
|
|
@ -18,11 +18,13 @@ class StyleTransferEvaluator(evaluator.Evaluator):
|
|||
self.gen_output_index = opt_eval['gen_index'] if 'gen_index' in opt_eval.keys() else 0
|
||||
self.noise_type = opt_get(opt_eval, ['noise_type'], 'imgnoise')
|
||||
self.latent_dim = opt_get(opt_eval, ['latent_dim'], 512) # Not needed if using 'imgnoise' input.
|
||||
self.image_norm_range = tuple(opt_get(env['opt'], ['image_normalization_range'], [0,1]))
|
||||
|
||||
def perform_eval(self):
|
||||
fid_fake_path = osp.join(self.env['base_path'], "../", "fid", str(self.env["step"]))
|
||||
os.makedirs(fid_fake_path, exist_ok=True)
|
||||
counter = 0
|
||||
self.model.eval()
|
||||
for i in range(self.batches_per_eval):
|
||||
if self.noise_type == 'imgnoise':
|
||||
batch = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
|
||||
|
@ -32,9 +34,11 @@ class StyleTransferEvaluator(evaluator.Evaluator):
|
|||
if not isinstance(gen, list) and not isinstance(gen, tuple):
|
||||
gen = [gen]
|
||||
gen = gen[self.gen_output_index]
|
||||
gen = (gen - self.image_norm_range[0]) / (self.image_norm_range[1]-self.image_norm_range[0])
|
||||
for b in range(self.batch_sz):
|
||||
torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter)))
|
||||
counter += 1
|
||||
self.model.train()
|
||||
|
||||
print("Got all images, computing fid")
|
||||
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.batch_sz, True,
|
||||
|
|
Loading…
Reference in New Issue
Block a user