Various changes to fix testing

James Betker 2021-06-11 15:31:10 -06:00
parent 220f11a5e4
commit 65c474eecf
10 changed files with 48 additions and 29 deletions

View File

@@ -3,18 +3,20 @@ import logging
import torch
import torch.utils.data
from utils.util import opt_get
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
phase = dataset_opt['phase']
if phase == 'train':
if opt['dist']:
if opt_get(opt, ['dist'], False):
world_size = torch.distributed.get_world_size()
num_workers = dataset_opt['n_workers']
assert dataset_opt['batch_size'] % world_size == 0
batch_size = dataset_opt['batch_size'] // world_size
shuffle = False
else:
num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
num_workers = dataset_opt['n_workers']
batch_size = dataset_opt['batch_size']
shuffle = True
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
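
The switch from opt['dist'] to opt_get(opt, ['dist'], False) makes create_dataloader tolerant of a missing key or of opt being None (as the new opt=None default allows). A minimal sketch of what opt_get is assumed to do; the real utils.util.opt_get may differ in detail:

def opt_get(opt, keys, default=None):
    # Walk a nested dict, falling back to `default` if any level is missing or None.
    if opt is None:
        return default
    cur = opt
    for k in keys:
        cur = cur.get(k, None) if isinstance(cur, dict) else None
        if cur is None:
            return default
    return cur

assert opt_get(None, ['dist'], False) is False            # opt omitted entirely
assert opt_get({'dist': True}, ['dist'], False) is True   # key present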

View File

@@ -37,7 +37,10 @@ class ImageCorruptor:
def corrupt_images(self, imgs, return_entropy=False):
if self.num_corrupts == 0 and not self.fixed_corruptions:
return imgs
if return_entropy:
return imgs, []
else:
return imgs
if self.num_corrupts == 0:
augmentations = []
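
The early exit now keeps the return arity consistent: callers asking for entropy always receive a two-tuple, even when no corruptions are configured. A minimal sketch of that contract (not the real class, just the return-shape convention):

def corrupt_images(imgs, return_entropy=False, num_corrupts=0, fixed_corruptions=()):
    # Mirror of the guard above: nothing to corrupt, but keep the return shape stable.
    if num_corrupts == 0 and not fixed_corruptions:
        return (imgs, []) if return_entropy else imgs
    ...  # corruption pipeline elided

print(corrupt_images(['img']))                         # ['img']
print(corrupt_images(['img'], return_entropy=True))    # (['img'], [])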

View File

@@ -35,6 +35,8 @@ class ImageFolderDataset:
self.skip_lq = opt_get(opt, ['skip_lq'], False)
self.disable_flip = opt_get(opt, ['disable_flip'], False)
self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
self.force_square = opt_get(opt, ['force_square'], True)
self.fixed_parameters = {k: torch.tensor(v) for k, v in opt_get(opt, ['fixed_parameters'], {}).items()}
if 'normalize' in opt.keys():
if opt['normalize'] == 'stylegan2_norm':
self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
@@ -44,7 +46,8 @@ class ImageFolderDataset:
raise Exception('Unsupported normalize')
else:
self.normalize = None
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we don't throw here, we get some really obscure errors.
if self.target_hq_size is not None:
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we don't throw here, we get some really obscure errors.
if not isinstance(self.paths, list):
self.paths = [self.paths]
self.weights = [1]
@@ -129,10 +132,10 @@ class ImageFolderDataset:
if not self.disable_flip and random.random() < .5:
hq = hq[:, ::-1, :]
# We must convert the image into a square.
h, w, _ = hq.shape
dim = min(h, w)
hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
if self.force_square:
h, w, _ = hq.shape
dim = min(h, w)
hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
if self.labeler:
assert hq.shape[0] == hq.shape[1] # This just has not been accommodated yet.
@@ -211,6 +214,7 @@ class ImageFolderDataset:
v = v * 2 - 1
out_dict[k] = v
out_dict.update(self.fixed_parameters)
return out_dict
if __name__ == '__main__':
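
force_square defaults to True, so existing configs keep the old centre-crop behaviour, while setting it to False now preserves rectangular images. A standalone sketch of the crop itself, using a made-up 480x640 array:

import numpy as np

hq = np.zeros((480, 640, 3))   # H x W x C
h, w, _ = hq.shape
dim = min(h, w)
hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
print(hq.shape)                # (480, 480, 3): the centre square of the wider axis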

View File

@@ -241,8 +241,8 @@ class BYOL(nn.Module):
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
def forward(self, image_one, image_two):
image_one = self.aug(image_one)
image_two = self.aug(image_two)
image_one = self.aug(image_one.clone())
image_two = self.aug(image_two.clone())
# Keep copies on hand for visual_dbg.
self.im1 = image_one.detach().clone()
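
The added .clone() calls protect the caller's tensors in case self.aug applies in-place operations; whether it actually does is an assumption here, but the aliasing hazard being avoided is easy to demonstrate:

import torch

x = torch.rand(1, 3, 32, 32)
alias = x                  # same underlying storage
alias.mul_(0)              # an in-place op through the alias also zeroes x
assert x.abs().sum() == 0

y = torch.rand(1, 3, 32, 32)
copy = y.clone()           # independent storage
copy.mul_(0)               # y is left untouched
assert y.abs().sum() > 0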

View File

@@ -1,5 +1,6 @@
import os.path as osp
import logging
import random
import time
import argparse
from collections import OrderedDict
@@ -11,9 +12,10 @@ from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
def forward_pass(model, output_dir, opt):
def forward_pass(model, data, output_dir, opt):
alteration_suffix = util.opt_get(opt, ['name'], '')
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
model.feed_data(data, 0, need_GT=need_GT)
@@ -47,11 +49,16 @@ def forward_pass(model, output_dir, opt):
if __name__ == "__main__":
# Set seeds
torch.manual_seed(5555)
random.seed(5555)
np.random.seed(5555)
#### options
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_cats_stylegan2_rosinality.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@@ -93,7 +100,7 @@ if __name__ == "__main__":
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
need_GT = need_GT and want_metrics
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt)
fea_loss, psnr_loss = forward_pass(model, data, dataset_dir, opt)
fea_loss += fea_loss
psnr_loss += psnr_loss
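
test.py now seeds all three RNG streams it uses (torch, python's random, numpy). A small helper covering the same ground, as a sketch:

import random
import numpy as np
import torch

def seed_everything(seed: int) -> None:
    torch.manual_seed(seed)   # torch CPU generator (recent versions also seed CUDA from this)
    random.seed(seed)         # python stdlib RNG
    np.random.seed(seed)      # numpy global RNG

seed_everything(5555)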

View File

@@ -302,7 +302,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_cifar.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@@ -159,7 +159,8 @@ class ExtensibleTrainer(BaseModel):
self.batch_factor = self.mega_batch_factor
self.opt['checkpointing_enabled'] = self.checkpointing_cache
# The batch factor can be adjusted periodically to allow known high-memory steps to fit in GPU memory.
if 'mod_batch_factor' in self.opt['train'].keys() and \
if 'train' in self.opt.keys() and \
'mod_batch_factor' in self.opt['train'].keys() and \
self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
self.batch_factor = self.opt['train']['mod_batch_factor']
if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
@@ -350,8 +351,7 @@
def get_current_visuals(self, need_GT=True):
# Conforms to an archaic format from MMSR.
res = {'lq': self.eval_state['lq'][0].float().cpu(),
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
res = {'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
if 'hq' in self.eval_state.keys():
res['hq'] = self.eval_state['hq'][0].float().cpu()
return res
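
The extra 'train' membership check matters because inference-only configs may not define a train block at all, and the old chained lookup would fail on them. A tiny sketch with hypothetical opt dicts:

opt_eval = {'eval': {'output_state': 'gen'}}   # e.g. a test config with no train block
opt_train = {'train': {'mod_batch_factor': 1, 'mod_batch_factor_every': 4}}

for opt in (opt_eval, opt_train):
    if 'train' in opt.keys() and 'mod_batch_factor' in opt['train'].keys():
        print('would adjust the batch factor')
    else:
        print('skipped safely')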

View File

@@ -40,7 +40,9 @@ class GaussianDiffusionInferenceInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.generator = opt['generator']
self.output_shape = opt['output_shape']
self.output_batch_size = opt['output_batch_size']
self.output_scale_factor = opt['output_scale_factor']
self.undo_n1_to_1 = opt_get(opt, ['undo_n1_to_1'], False) # Explanation: when specified, will shift the output of this injector from [-1,1] to [0,1]
opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
[opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])])
@@ -49,9 +51,12 @@
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
batch_size = self.output_shape[0]
model_inputs = {k: state[v][:batch_size] for k, v in self.model_input_keys.items()}
model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
gen.eval()
with torch.no_grad():
gen = self.diffusion.p_sample_loop(gen, self.output_shape, model_kwargs=model_inputs)
output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
gen = self.diffusion.p_sample_loop(gen, output_shape, model_kwargs=model_inputs)
if self.undo_n1_to_1:
gen = (gen + 1) / 2
return {self.output: gen}
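
Deriving the output shape from the conditioning input means the injector no longer hard-codes output dimensions. A worked sketch of the computation, assuming a 16x16 'low_res' batch and output_scale_factor = 4:

import torch

low_res = torch.zeros(8, 3, 16, 16)      # hypothetical conditioning batch
output_batch_size, output_scale_factor = 4, 4
output_shape = (output_batch_size, 3,
                low_res.shape[-2] * output_scale_factor,
                low_res.shape[-1] * output_scale_factor)
print(output_shape)                       # (4, 3, 64, 64)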

View File

@@ -7,7 +7,7 @@ from trainer.losses import create_loss
import torch
from collections import OrderedDict
from trainer.inject import create_injector
from utils.util import recursively_detach
from utils.util import recursively_detach, opt_get
logger = logging.getLogger('base')
@@ -53,21 +53,19 @@ class ConfigurableStep(Module):
# This default implementation defines a single optimizer for all Generator parameters.
# Must be called after networks are initialized and wrapped.
def define_optimizers(self):
opt_configs = opt_get(self.step_opt, ['optimizer_params'], None)
self.optimizers = []
if opt_configs is None:
return
training = self.step_opt['training']
training_net = self.get_network_for_name(training)
# When only training one network, optimizer params can just be embedded in the step params.
if 'optimizer_params' not in self.step_opt.keys():
opt_configs = [self.step_opt]
else:
opt_configs = [self.step_opt['optimizer_params']]
nets = [training_net]
training = [training]
self.optimizers = []
for net_name, net, opt_config in zip(training, nets, opt_configs):
# Configs can organize parameters by group and specify different learning rates for each group. This only
# works if the model specifically annotates which parameters belong in which group using PARAM_GROUP.
optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
if 'param_groups' in opt_config.keys():
if opt_config is not None and 'param_groups' in opt_config.keys():
for k, pg in opt_config['param_groups'].items():
optim_params[k] = {'params': [], 'lr': pg['lr']}
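
For reference, a sketch of the structure this loop builds, with made-up group names and learning rates (only the dict bookkeeping is reproduced, not the parameter collection):

opt_config = {'lr': 1e-4, 'param_groups': {'decoder': {'lr': 1e-5}}}
optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
if opt_config is not None and 'param_groups' in opt_config.keys():
    for k, pg in opt_config['param_groups'].items():
        optim_params[k] = {'params': [], 'lr': pg['lr']}
print(optim_params)
# {'default': {'params': [], 'lr': 0.0001}, 'decoder': {'params': [], 'lr': 1e-05}}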