Various changes to fix testing
This commit is contained in:
parent
220f11a5e4
commit
65c474eecf
|
@ -3,18 +3,20 @@ import logging
|
|||
import torch
|
||||
import torch.utils.data
|
||||
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
||||
phase = dataset_opt['phase']
|
||||
if phase == 'train':
|
||||
if opt['dist']:
|
||||
if opt_get(opt, ['dist'], False):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
num_workers = dataset_opt['n_workers']
|
||||
assert dataset_opt['batch_size'] % world_size == 0
|
||||
batch_size = dataset_opt['batch_size'] // world_size
|
||||
shuffle = False
|
||||
else:
|
||||
num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
|
||||
num_workers = dataset_opt['n_workers']
|
||||
batch_size = dataset_opt['batch_size']
|
||||
shuffle = True
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||
|
|
|
@ -37,6 +37,9 @@ class ImageCorruptor:
|
|||
|
||||
def corrupt_images(self, imgs, return_entropy=False):
|
||||
if self.num_corrupts == 0 and not self.fixed_corruptions:
|
||||
if return_entropy:
|
||||
return imgs, []
|
||||
else:
|
||||
return imgs
|
||||
|
||||
if self.num_corrupts == 0:
|
||||
|
|
|
@ -35,6 +35,8 @@ class ImageFolderDataset:
|
|||
self.skip_lq = opt_get(opt, ['skip_lq'], False)
|
||||
self.disable_flip = opt_get(opt, ['disable_flip'], False)
|
||||
self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
|
||||
self.force_square = opt_get(opt, ['force_square'], True)
|
||||
self.fixed_parameters = {k: torch.tensor(v) for k, v in opt_get(opt, ['fixed_parameters'], {}).items()}
|
||||
if 'normalize' in opt.keys():
|
||||
if opt['normalize'] == 'stylegan2_norm':
|
||||
self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
|
@ -44,6 +46,7 @@ class ImageFolderDataset:
|
|||
raise Exception('Unsupported normalize')
|
||||
else:
|
||||
self.normalize = None
|
||||
if self.target_hq_size is not None:
|
||||
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
|
||||
if not isinstance(self.paths, list):
|
||||
self.paths = [self.paths]
|
||||
|
@ -129,7 +132,7 @@ class ImageFolderDataset:
|
|||
if not self.disable_flip and random.random() < .5:
|
||||
hq = hq[:, ::-1, :]
|
||||
|
||||
# We must convert the image into a square.
|
||||
if self.force_square:
|
||||
h, w, _ = hq.shape
|
||||
dim = min(h, w)
|
||||
hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
||||
|
@ -211,6 +214,7 @@ class ImageFolderDataset:
|
|||
v = v * 2 - 1
|
||||
out_dict[k] = v
|
||||
|
||||
out_dict.update(self.fixed_parameters)
|
||||
return out_dict
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -241,8 +241,8 @@ class BYOL(nn.Module):
|
|||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||
|
||||
def forward(self, image_one, image_two):
|
||||
image_one = self.aug(image_one)
|
||||
image_two = self.aug(image_two)
|
||||
image_one = self.aug(image_one.clone())
|
||||
image_two = self.aug(image_two.clone())
|
||||
|
||||
# Keep copies on hand for visual_dbg.
|
||||
self.im1 = image_one.detach().clone()
|
||||
|
|
0
codes/scripts/diffusion/diffusion_sampler.py
Normal file
0
codes/scripts/diffusion/diffusion_sampler.py
Normal file
|
@ -1,5 +1,6 @@
|
|||
import os.path as osp
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
@ -11,9 +12,10 @@ from trainer.ExtensibleTrainer import ExtensibleTrainer
|
|||
from data import create_dataset, create_dataloader
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def forward_pass(model, output_dir, opt):
|
||||
def forward_pass(model, data, output_dir, opt):
|
||||
alteration_suffix = util.opt_get(opt, ['name'], '')
|
||||
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
|
||||
model.feed_data(data, 0, need_GT=need_GT)
|
||||
|
@ -47,11 +49,16 @@ def forward_pass(model, output_dir, opt):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set seeds
|
||||
torch.manual_seed(5555)
|
||||
random.seed(5555)
|
||||
np.random.seed(5555)
|
||||
|
||||
#### options
|
||||
torch.backends.cudnn.benchmark = True
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_cats_stylegan2_rosinality.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
@ -93,7 +100,7 @@ if __name__ == "__main__":
|
|||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||
need_GT = need_GT and want_metrics
|
||||
|
||||
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt)
|
||||
fea_loss, psnr_loss = forward_pass(model, data, dataset_dir, opt)
|
||||
fea_loss += fea_loss
|
||||
psnr_loss += psnr_loss
|
||||
|
||||
|
|
|
@ -302,7 +302,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_cifar.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -159,7 +159,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
self.batch_factor = self.mega_batch_factor
|
||||
self.opt['checkpointing_enabled'] = self.checkpointing_cache
|
||||
# The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory.
|
||||
if 'mod_batch_factor' in self.opt['train'].keys() and \
|
||||
if 'train' in self.opt.keys() and \
|
||||
'mod_batch_factor' in self.opt['train'].keys() and \
|
||||
self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
|
||||
self.batch_factor = self.opt['train']['mod_batch_factor']
|
||||
if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
|
||||
|
@ -350,8 +351,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
def get_current_visuals(self, need_GT=True):
|
||||
# Conforms to an archaic format from MMSR.
|
||||
res = {'lq': self.eval_state['lq'][0].float().cpu(),
|
||||
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
||||
res = {'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
||||
if 'hq' in self.eval_state.keys():
|
||||
res['hq'] = self.eval_state['hq'][0].float().cpu(),
|
||||
return res
|
||||
|
|
|
@ -40,7 +40,9 @@ class GaussianDiffusionInferenceInjector(Injector):
|
|||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.generator = opt['generator']
|
||||
self.output_shape = opt['output_shape']
|
||||
self.output_batch_size = opt['output_batch_size']
|
||||
self.output_scale_factor = opt['output_scale_factor']
|
||||
self.undo_n1_to_1 = opt_get(opt, ['undo_n1_to_1'], False) # Explanation: when specified, will shift the output of this injector from [-1,1] to [0,1]
|
||||
opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
|
||||
opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
|
||||
[opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])])
|
||||
|
@ -49,9 +51,12 @@ class GaussianDiffusionInferenceInjector(Injector):
|
|||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
batch_size = self.output_shape[0]
|
||||
model_inputs = {k: state[v][:batch_size] for k, v in self.model_input_keys.items()}
|
||||
model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
|
||||
gen.eval()
|
||||
with torch.no_grad():
|
||||
gen = self.diffusion.p_sample_loop(gen, self.output_shape, model_kwargs=model_inputs)
|
||||
output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
|
||||
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
|
||||
gen = self.diffusion.p_sample_loop(gen, output_shape, model_kwargs=model_inputs)
|
||||
if self.undo_n1_to_1:
|
||||
gen = (gen + 1) / 2
|
||||
return {self.output: gen}
|
||||
|
|
|
@ -7,7 +7,7 @@ from trainer.losses import create_loss
|
|||
import torch
|
||||
from collections import OrderedDict
|
||||
from trainer.inject import create_injector
|
||||
from utils.util import recursively_detach
|
||||
from utils.util import recursively_detach, opt_get
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -53,21 +53,19 @@ class ConfigurableStep(Module):
|
|||
# This default implementation defines a single optimizer for all Generator parameters.
|
||||
# Must be called after networks are initialized and wrapped.
|
||||
def define_optimizers(self):
|
||||
opt_configs = opt_get(self.step_opt, ['optimizer_params'], None)
|
||||
self.optimizers = []
|
||||
if opt_configs is None:
|
||||
return
|
||||
training = self.step_opt['training']
|
||||
training_net = self.get_network_for_name(training)
|
||||
# When only training one network, optimizer params can just embedded in the step params.
|
||||
if 'optimizer_params' not in self.step_opt.keys():
|
||||
opt_configs = [self.step_opt]
|
||||
else:
|
||||
opt_configs = [self.step_opt['optimizer_params']]
|
||||
nets = [training_net]
|
||||
training = [training]
|
||||
self.optimizers = []
|
||||
for net_name, net, opt_config in zip(training, nets, opt_configs):
|
||||
# Configs can organize parameters by-group and specify different learning rates for each group. This only
|
||||
# works in the model specifically annotates which parameters belong in which group using PARAM_GROUP.
|
||||
optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
|
||||
if 'param_groups' in opt_config.keys():
|
||||
if opt_config is not None and 'param_groups' in opt_config.keys():
|
||||
for k, pg in opt_config['param_groups'].items():
|
||||
optim_params[k] = {'params': [], 'lr': pg['lr']}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user