Various changes to fix testing
This commit is contained in:
parent
220f11a5e4
commit
65c474eecf
|
@ -3,18 +3,20 @@ import logging
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
|
|
||||||
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
||||||
phase = dataset_opt['phase']
|
phase = dataset_opt['phase']
|
||||||
if phase == 'train':
|
if phase == 'train':
|
||||||
if opt['dist']:
|
if opt_get(opt, ['dist'], False):
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = torch.distributed.get_world_size()
|
||||||
num_workers = dataset_opt['n_workers']
|
num_workers = dataset_opt['n_workers']
|
||||||
assert dataset_opt['batch_size'] % world_size == 0
|
assert dataset_opt['batch_size'] % world_size == 0
|
||||||
batch_size = dataset_opt['batch_size'] // world_size
|
batch_size = dataset_opt['batch_size'] // world_size
|
||||||
shuffle = False
|
shuffle = False
|
||||||
else:
|
else:
|
||||||
num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
|
num_workers = dataset_opt['n_workers']
|
||||||
batch_size = dataset_opt['batch_size']
|
batch_size = dataset_opt['batch_size']
|
||||||
shuffle = True
|
shuffle = True
|
||||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||||
|
|
|
@ -37,7 +37,10 @@ class ImageCorruptor:
|
||||||
|
|
||||||
def corrupt_images(self, imgs, return_entropy=False):
|
def corrupt_images(self, imgs, return_entropy=False):
|
||||||
if self.num_corrupts == 0 and not self.fixed_corruptions:
|
if self.num_corrupts == 0 and not self.fixed_corruptions:
|
||||||
return imgs
|
if return_entropy:
|
||||||
|
return imgs, []
|
||||||
|
else:
|
||||||
|
return imgs
|
||||||
|
|
||||||
if self.num_corrupts == 0:
|
if self.num_corrupts == 0:
|
||||||
augmentations = []
|
augmentations = []
|
||||||
|
|
|
@ -35,6 +35,8 @@ class ImageFolderDataset:
|
||||||
self.skip_lq = opt_get(opt, ['skip_lq'], False)
|
self.skip_lq = opt_get(opt, ['skip_lq'], False)
|
||||||
self.disable_flip = opt_get(opt, ['disable_flip'], False)
|
self.disable_flip = opt_get(opt, ['disable_flip'], False)
|
||||||
self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
|
self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False)
|
||||||
|
self.force_square = opt_get(opt, ['force_square'], True)
|
||||||
|
self.fixed_parameters = {k: torch.tensor(v) for k, v in opt_get(opt, ['fixed_parameters'], {}).items()}
|
||||||
if 'normalize' in opt.keys():
|
if 'normalize' in opt.keys():
|
||||||
if opt['normalize'] == 'stylegan2_norm':
|
if opt['normalize'] == 'stylegan2_norm':
|
||||||
self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||||
|
@ -44,7 +46,8 @@ class ImageFolderDataset:
|
||||||
raise Exception('Unsupported normalize')
|
raise Exception('Unsupported normalize')
|
||||||
else:
|
else:
|
||||||
self.normalize = None
|
self.normalize = None
|
||||||
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
|
if self.target_hq_size is not None:
|
||||||
|
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
|
||||||
if not isinstance(self.paths, list):
|
if not isinstance(self.paths, list):
|
||||||
self.paths = [self.paths]
|
self.paths = [self.paths]
|
||||||
self.weights = [1]
|
self.weights = [1]
|
||||||
|
@ -129,10 +132,10 @@ class ImageFolderDataset:
|
||||||
if not self.disable_flip and random.random() < .5:
|
if not self.disable_flip and random.random() < .5:
|
||||||
hq = hq[:, ::-1, :]
|
hq = hq[:, ::-1, :]
|
||||||
|
|
||||||
# We must convert the image into a square.
|
if self.force_square:
|
||||||
h, w, _ = hq.shape
|
h, w, _ = hq.shape
|
||||||
dim = min(h, w)
|
dim = min(h, w)
|
||||||
hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
||||||
|
|
||||||
if self.labeler:
|
if self.labeler:
|
||||||
assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet.
|
assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet.
|
||||||
|
@ -211,6 +214,7 @@ class ImageFolderDataset:
|
||||||
v = v * 2 - 1
|
v = v * 2 - 1
|
||||||
out_dict[k] = v
|
out_dict[k] = v
|
||||||
|
|
||||||
|
out_dict.update(self.fixed_parameters)
|
||||||
return out_dict
|
return out_dict
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -241,8 +241,8 @@ class BYOL(nn.Module):
|
||||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||||
|
|
||||||
def forward(self, image_one, image_two):
|
def forward(self, image_one, image_two):
|
||||||
image_one = self.aug(image_one)
|
image_one = self.aug(image_one.clone())
|
||||||
image_two = self.aug(image_two)
|
image_two = self.aug(image_two.clone())
|
||||||
|
|
||||||
# Keep copies on hand for visual_dbg.
|
# Keep copies on hand for visual_dbg.
|
||||||
self.im1 = image_one.detach().clone()
|
self.im1 = image_one.detach().clone()
|
||||||
|
|
0
codes/scripts/diffusion/diffusion_sampler.py
Normal file
0
codes/scripts/diffusion/diffusion_sampler.py
Normal file
|
@ -1,5 +1,6 @@
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
@ -11,9 +12,10 @@ from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def forward_pass(model, output_dir, opt):
|
def forward_pass(model, data, output_dir, opt):
|
||||||
alteration_suffix = util.opt_get(opt, ['name'], '')
|
alteration_suffix = util.opt_get(opt, ['name'], '')
|
||||||
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
|
denorm_range = tuple(util.opt_get(opt, ['image_normalization_range'], [0, 1]))
|
||||||
model.feed_data(data, 0, need_GT=need_GT)
|
model.feed_data(data, 0, need_GT=need_GT)
|
||||||
|
@ -47,11 +49,16 @@ def forward_pass(model, output_dir, opt):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# Set seeds
|
||||||
|
torch.manual_seed(5555)
|
||||||
|
random.seed(5555)
|
||||||
|
np.random.seed(5555)
|
||||||
|
|
||||||
#### options
|
#### options
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
want_metrics = False
|
want_metrics = False
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_cats_stylegan2_rosinality.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
utils.util.loaded_options = opt
|
utils.util.loaded_options = opt
|
||||||
|
@ -93,7 +100,7 @@ if __name__ == "__main__":
|
||||||
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
|
||||||
need_GT = need_GT and want_metrics
|
need_GT = need_GT and want_metrics
|
||||||
|
|
||||||
fea_loss, psnr_loss = forward_pass(model, dataset_dir, opt)
|
fea_loss, psnr_loss = forward_pass(model, data, dataset_dir, opt)
|
||||||
fea_loss += fea_loss
|
fea_loss += fea_loss
|
||||||
psnr_loss += psnr_loss
|
psnr_loss += psnr_loss
|
||||||
|
|
||||||
|
|
|
@ -302,7 +302,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_cifar.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -159,7 +159,8 @@ class ExtensibleTrainer(BaseModel):
|
||||||
self.batch_factor = self.mega_batch_factor
|
self.batch_factor = self.mega_batch_factor
|
||||||
self.opt['checkpointing_enabled'] = self.checkpointing_cache
|
self.opt['checkpointing_enabled'] = self.checkpointing_cache
|
||||||
# The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory.
|
# The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory.
|
||||||
if 'mod_batch_factor' in self.opt['train'].keys() and \
|
if 'train' in self.opt.keys() and \
|
||||||
|
'mod_batch_factor' in self.opt['train'].keys() and \
|
||||||
self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
|
self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
|
||||||
self.batch_factor = self.opt['train']['mod_batch_factor']
|
self.batch_factor = self.opt['train']['mod_batch_factor']
|
||||||
if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
|
if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
|
||||||
|
@ -350,8 +351,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
|
|
||||||
def get_current_visuals(self, need_GT=True):
|
def get_current_visuals(self, need_GT=True):
|
||||||
# Conforms to an archaic format from MMSR.
|
# Conforms to an archaic format from MMSR.
|
||||||
res = {'lq': self.eval_state['lq'][0].float().cpu(),
|
res = {'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
||||||
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
|
|
||||||
if 'hq' in self.eval_state.keys():
|
if 'hq' in self.eval_state.keys():
|
||||||
res['hq'] = self.eval_state['hq'][0].float().cpu(),
|
res['hq'] = self.eval_state['hq'][0].float().cpu(),
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -40,7 +40,9 @@ class GaussianDiffusionInferenceInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super().__init__(opt, env)
|
super().__init__(opt, env)
|
||||||
self.generator = opt['generator']
|
self.generator = opt['generator']
|
||||||
self.output_shape = opt['output_shape']
|
self.output_batch_size = opt['output_batch_size']
|
||||||
|
self.output_scale_factor = opt['output_scale_factor']
|
||||||
|
self.undo_n1_to_1 = opt_get(opt, ['undo_n1_to_1'], False) # Explanation: when specified, will shift the output of this injector from [-1,1] to [0,1]
|
||||||
opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
|
opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
|
||||||
opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
|
opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
|
||||||
[opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])])
|
[opt_get(opt, ['respaced_timestep_spacing'], opt['beta_schedule']['num_diffusion_timesteps'])])
|
||||||
|
@ -49,9 +51,12 @@ class GaussianDiffusionInferenceInjector(Injector):
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
batch_size = self.output_shape[0]
|
model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
|
||||||
model_inputs = {k: state[v][:batch_size] for k, v in self.model_input_keys.items()}
|
|
||||||
gen.eval()
|
gen.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
gen = self.diffusion.p_sample_loop(gen, self.output_shape, model_kwargs=model_inputs)
|
output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
|
||||||
|
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
|
||||||
|
gen = self.diffusion.p_sample_loop(gen, output_shape, model_kwargs=model_inputs)
|
||||||
|
if self.undo_n1_to_1:
|
||||||
|
gen = (gen + 1) / 2
|
||||||
return {self.output: gen}
|
return {self.output: gen}
|
||||||
|
|
|
@ -7,7 +7,7 @@ from trainer.losses import create_loss
|
||||||
import torch
|
import torch
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from trainer.inject import create_injector
|
from trainer.inject import create_injector
|
||||||
from utils.util import recursively_detach
|
from utils.util import recursively_detach, opt_get
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
@ -53,21 +53,19 @@ class ConfigurableStep(Module):
|
||||||
# This default implementation defines a single optimizer for all Generator parameters.
|
# This default implementation defines a single optimizer for all Generator parameters.
|
||||||
# Must be called after networks are initialized and wrapped.
|
# Must be called after networks are initialized and wrapped.
|
||||||
def define_optimizers(self):
|
def define_optimizers(self):
|
||||||
|
opt_configs = opt_get(self.step_opt, ['optimizer_params'], None)
|
||||||
|
self.optimizers = []
|
||||||
|
if opt_configs is None:
|
||||||
|
return
|
||||||
training = self.step_opt['training']
|
training = self.step_opt['training']
|
||||||
training_net = self.get_network_for_name(training)
|
training_net = self.get_network_for_name(training)
|
||||||
# When only training one network, optimizer params can just embedded in the step params.
|
|
||||||
if 'optimizer_params' not in self.step_opt.keys():
|
|
||||||
opt_configs = [self.step_opt]
|
|
||||||
else:
|
|
||||||
opt_configs = [self.step_opt['optimizer_params']]
|
|
||||||
nets = [training_net]
|
nets = [training_net]
|
||||||
training = [training]
|
training = [training]
|
||||||
self.optimizers = []
|
|
||||||
for net_name, net, opt_config in zip(training, nets, opt_configs):
|
for net_name, net, opt_config in zip(training, nets, opt_configs):
|
||||||
# Configs can organize parameters by-group and specify different learning rates for each group. This only
|
# Configs can organize parameters by-group and specify different learning rates for each group. This only
|
||||||
# works in the model specifically annotates which parameters belong in which group using PARAM_GROUP.
|
# works in the model specifically annotates which parameters belong in which group using PARAM_GROUP.
|
||||||
optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
|
optim_params = {'default': {'params': [], 'lr': opt_config['lr']}}
|
||||||
if 'param_groups' in opt_config.keys():
|
if opt_config is not None and 'param_groups' in opt_config.keys():
|
||||||
for k, pg in opt_config['param_groups'].items():
|
for k, pg in opt_config['param_groups'].items():
|
||||||
optim_params[k] = {'params': [], 'lr': pg['lr']}
|
optim_params[k] = {'params': [], 'lr': pg['lr']}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user