This commit is contained in:
James Betker 2021-04-21 18:09:09 -06:00
parent 94e069bced
commit b687ef4cd0
9 changed files with 59 additions and 366 deletions

1
.gitignore vendored
View File

@ -18,6 +18,7 @@ data/*
*.cu
*.pt
*.pth
*.pdf
# template

View File

@ -3,5 +3,5 @@
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (torch_nightly)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (torch18)" project-jdk-type="Python SDK" />
</project>

View File

@ -9,7 +9,7 @@
<excludeFolder url="file://$MODULE_DIR$/results" />
<excludeFolder url="file://$MODULE_DIR$/tb_logger" />
</content>
<orderEntry type="jdk" jdkName="Python 3.8 (torch_nightly)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.8 (torch18)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">

View File

@ -5,10 +5,14 @@ import random
import cv2
import kornia
import numpy as np
import pytorch_ssim
import torch
import os
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import Normalize
from tqdm import tqdm
from data import util
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
@ -140,7 +144,7 @@ class ImageFolderDataset:
if self.normalize:
hq = self.normalize(hq)
out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item], 'has_alt': False}
if self.fetch_alt_image:
# This works by assuming a specific filename structure as would produced by ffmpeg. ex:
@ -165,6 +169,7 @@ class ImageFolderDataset:
alt_hq = util.read_img(None, next_img, rgb=True)
alt_hs = self.resize_hq([alt_hq])
alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
out_dict['has_alt'] = True
if not self.skip_lq:
for_lq.append(alt_hs[0])
except:
@ -200,33 +205,25 @@ class ImageFolderDataset:
if __name__ == '__main__':
opt = {
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\images\\youtube\\4k_quote_unquote\\images'],
'paths': ['E:\\4k6k\\datasets\\ns_images\\256_unsupervised'],
'weights': [1],
'target_size': 256,
'force_multiple': 32,
'force_multiple': 1,
'scale': 2,
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': False,
'corrupt_before_downsize': True,
'fetch_alt_image': True,
#'labeler': {
# 'type': 'patch_labels',
# 'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
#}
'disable_flip': True,
'fixed_corruptions': [ 'jpeg-broad' ],
'num_corrupts_per_image': 0,
'corruption_blur_scale': 0
}
ds = ImageFolderDataset(opt)
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=2)
import os
os.makedirs("debug", exist_ok=True)
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds)-1)]
hq = o['lq']
#masked = (o['labels_mask'] * .5 + .5) * hq
import torchvision
torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_lq.png" % (i,))
torchvision.utils.save_image(o['alt_lq'].unsqueeze(0), "debug/%i_lq_alt.png" % (i,))
#if len(o['labels'].unique()) > 1:
# randlbl = np.random.choice(o['labels'].unique()[1:])
# moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
# torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
os.makedirs(output_path, exist_ok=True)
for i, d in tqdm(enumerate(ds)):
lq = d['lq']
torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
if i >= 200000:
break

View File

View File

@ -208,7 +208,7 @@ def run_tsne_instance_level():
print("Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset.")
limit = 4000
X, files = torch.load('results.pth')
X, files = torch.load('../results_instance_resnet.pth')
zipped = list(zip(X, files))
shuffle(zipped)
X, files = zip(*zipped)
@ -242,7 +242,7 @@ def run_tsne_instance_level():
# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
# spectrum.
def plot_instance_level_results_as_image_graph():
Y, files = torch.load('tsne_output.pth')
Y, files = torch.load('../tsne_output.pth')
fig, ax = pyplot.subplots()
fig.set_size_inches(200,200,forward=True)
ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
@ -250,7 +250,7 @@ def plot_instance_level_results_as_image_graph():
for b in tqdm(range(Y.shape[0])):
im = pyplot.imread(files[b])
im = OffsetImage(im, zoom=1/8)
im = OffsetImage(im, zoom=1/2)
ab = AnnotationBbox(im, (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False)
ax.add_artist(ab)
ax.scatter(Y[:, 0], Y[:, 1])
@ -277,7 +277,7 @@ def run_tsne_pixel_level():
'''
# For resnet-style latent tuples
X, files = torch.load('../results.pth')
X, files = torch.load('../../results/2021-4-8-imgset-latent-dict.pth')
zipped = list(zip(X, files))
shuffle(zipped)
X, files = zip(*zipped)
@ -347,8 +347,8 @@ def plot_pixel_level_results_as_image_graph():
if __name__ == "__main__":
# For use with instance-level results (e.g. from byol_resnet_playground.py)
#run_tsne_instance_level()
#plot_instance_level_results_as_image_graph()
plot_instance_level_results_as_image_graph()
# For use with pixel-level results (e.g. from byol_uresnet_playground)
#run_tsne_pixel_level()
plot_pixel_level_results_as_image_graph()
#plot_pixel_level_results_as_image_graph()

View File

@ -19,9 +19,9 @@ def main():
# compression time. If read raw images during training, use 0 for faster IO speed.
opt['dest'] = 'file'
opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\faces\\CelebAMask-HQ\\CelebA-HQ-img']
opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\faces\\CelebAMask-HQ\\256px'
opt['imgsize'] = 256
opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats']
opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\256_4_by_3'
opt['imgsize'] = (256,192)
opt['bottom_crop'] = 0
opt['keep_folder'] = False
@ -66,15 +66,25 @@ class TiledDataset(data.Dataset):
h, w, c = img.shape
# Uncomment to filter any image that doesnt meet a threshold size.
if min(h,w) < self.opt['imgsize']:
imgsz_w, imgsz_h = self.opt['imgsize']
if w < imgsz_w or h < imgsz_h:
print("Skipping due to threshold")
return None
# We must convert the image into a square.
dim = min(h, w)
# Crop the image so that only the center is left, since this is often the most salient part of the image.
img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
# We must first center-crop the image to the proper aspect ratio
aspect_ratio = imgsz_h / imgsz_w
if h < w * aspect_ratio:
hdim = h
wdim = int(h / aspect_ratio)
elif w * aspect_ratio < h:
hdim = int(w * aspect_ratio)
wdim = w
else:
hdim = h
wdim = w
img = img[(h - hdim) // 2:hdim + (h - hdim) // 2, (w - wdim) // 2:wdim + (w - wdim) // 2, :]
img = cv2.resize(img, (imgsz_w, imgsz_h), interpolation=cv2.INTER_AREA)
output_folder = self.opt['save_folder']
if self.opt['keep_folder']:
# Attempt to find the folder name one level above opt['input_folder'] and use that.

View File

@ -1,325 +0,0 @@
import os
import math
import argparse
import random
import logging
from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
from trainer.eval.evaluator import create_evaluator
from utils import util, options as option
from data import create_dataloader, create_dataset
from trainer.ExtensibleTrainer import ExtensibleTrainer
from time import time
def init_dist(backend, **kwargs):
# These packages have globals that screw with Windows, so only import them if needed.
import torch.distributed as dist
import torch.multiprocessing as mp
"""initialization for distributed training"""
if mp.get_start_method(allow_none=True) != 'spawn':
mp.set_start_method('spawn')
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
class Trainer:
def init(self, opt, launcher, all_networks={}):
self._profile = False
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'].keys() else True
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'].keys() else True
#### loading resume state if exists
if opt['path'].get('resume_state', None):
# distributed resuming: all load into default GPU
device_id = torch.cuda.current_device()
resume_state = torch.load(opt['path']['resume_state'],
map_location=lambda storage, loc: storage.cuda(device_id))
option.check_resume(opt, resume_state['iter']) # check resume options
else:
resume_state = None
#### mkdir and loggers
if self.rank <= 0: # normal training (self.rank -1) OR distributed training (self.rank 0)
if resume_state is None:
util.mkdir_and_rename(
opt['path']['experiments_root']) # rename experiment folder if exists
util.mkdirs(
(path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
and 'pretrain_model' not in key and 'resume' not in key))
# config loggers. Before it, the log will not work
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
screen=True, tofile=True)
self.logger = logging.getLogger('base')
self.logger.info(option.dict2str(opt))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
version = float(torch.__version__[0:3])
if version >= 1.1: # PyTorch 1.1
from torch.utils.tensorboard import SummaryWriter
else:
self.self.logger.info(
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
from tensorboardX import SummaryWriter
self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
else:
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
self.logger = logging.getLogger('base')
# convert to NoneDict, which returns None for missing keys
opt = option.dict_to_nonedict(opt)
self.opt = opt
#### wandb init
if opt['wandb']:
import wandb
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
wandb.init(project=opt['name'], dir=opt['path']['log'])
#### random seed
seed = opt['train']['manual_seed']
if seed is None:
seed = random.randint(1, 10000)
if self.rank <= 0:
self.logger.info('Random seed: {}'.format(seed))
seed += self.rank # Different multiprocessing instances should behave differently.
util.set_random_seed(seed)
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
# torch.autograd.set_detect_anomaly(True)
# Save the compiled opt dict to the global loaded_options variable.
util.loaded_options = opt
#### create train and val dataloader
dataset_ratio = 1 # enlarge the size of each epoch
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train':
self.train_set = create_dataset(dataset_opt)
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
total_iters = int(opt['train']['niter'])
self.total_epochs = int(math.ceil(total_iters / train_size))
if opt['dist']:
self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
else:
self.train_sampler = None
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler)
if self.rank <= 0:
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
len(self.train_set), train_size))
self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
self.total_epochs, total_iters))
elif phase == 'val':
self.val_set = create_dataset(dataset_opt)
self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None)
if self.rank <= 0:
self.logger.info('Number of val images in [{:s}]: {:d}'.format(
dataset_opt['name'], len(self.val_set)))
else:
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
assert self.train_loader is not None
#### create model
self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
### Evaluators
self.evaluators = []
if 'evaluators' in opt['eval'].keys():
for ev_key, ev_opt in opt['eval']['evaluators'].items():
self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
ev_opt, self.model.env))
#### resume training
if resume_state:
self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
resume_state['epoch'], resume_state['iter']))
self.start_epoch = resume_state['epoch']
self.current_step = resume_state['iter']
self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
else:
self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
self.start_epoch = 0
if 'force_start_step' in opt.keys():
self.current_step = opt['force_start_step']
opt['current_step'] = self.current_step
def do_step(self, train_data):
if self._profile:
print("Data fetch: %f" % (time() - _t))
_t = time()
opt = self.opt
self.current_step += 1
#### update learning rate
self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
#### training
if self._profile:
print("Update LR: %f" % (time() - _t))
_t = time()
self.model.feed_data(train_data, self.current_step)
self.model.optimize_parameters(self.current_step)
if self._profile:
print("Model feed + step: %f" % (time() - _t))
_t = time()
#### log
if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
logs = self.model.get_current_log(self.current_step)
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
for v in self.model.get_current_learning_rate():
message += '{:.3e},'.format(v)
message += ')] '
for k, v in logs.items():
if 'histogram' in k:
self.tb_logger.add_histogram(k, v, self.current_step)
elif isinstance(v, dict):
self.tb_logger.add_scalars(k, v, self.current_step)
else:
message += '{:s}: {:.4e} '.format(k, v)
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']:
self.tb_logger.add_scalar(k, v, self.current_step)
if opt['wandb']:
import wandb
wandb.log(logs)
self.logger.info(message)
#### save models and training states
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
if self.rank <= 0:
self.logger.info('Saving models and training states.')
self.model.save(self.current_step)
self.model.save_training_state(self.epoch, self.current_step)
if 'alt_path' in opt['path'].keys():
import shutil
print("Synchronizing tb_logger to alt_path..")
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
shutil.rmtree(alt_tblogger, ignore_errors=True)
shutil.copytree(self.tb_logger_path, alt_tblogger)
#### validation
if opt['datasets'].get('val', None) and self.current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan',
'extensibletrainer'] and self.rank <= 0: # image restoration validation
avg_psnr = 0.
avg_fea_loss = 0.
idx = 0
val_tqdm = tqdm(self.val_loader)
for val_data in val_tqdm:
idx += 1
for b in range(len(val_data['HQ_path'])):
img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir)
self.model.feed_data(val_data, self.current_step)
self.model.test()
visuals = self.model.get_current_visuals()
if visuals is None:
continue
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
save_img_path = os.path.join(img_dir, img_base_name)
util.save_img(sr_img, save_img_path)
avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
# log
self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0:
eval_dict = {}
for eval in self.evaluators:
eval_dict.update(eval.perform_eval())
if self.rank <= 0:
print("Evaluator results: ", eval_dict)
for ek, ev in eval_dict.items():
self.tb_logger.add_scalar(ek, ev, self.current_step)
def do_training(self):
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
for epoch in range(self.start_epoch, self.total_epochs + 1):
self.epoch = epoch
if opt['dist']:
self.train_sampler.set_epoch(epoch)
tq_ldr = tqdm(self.train_loader)
_t = time()
for train_data in tq_ldr:
self.do_step(train_data)
def create_training_generator(self, index):
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
for epoch in range(self.start_epoch, self.total_epochs + 1):
self.epoch = epoch
if self.opt['dist']:
self.train_sampler.set_epoch(epoch)
tq_ldr = tqdm(self.train_loader, position=index)
_t = time()
for train_data in tq_ldr:
yield self.model
self.do_step(train_data)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixpro_3.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
if args.launcher != 'none':
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
if 'gpu_ids' in opt.keys():
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
trainer = Trainer()
#### distributed training settings
if args.launcher == 'none': # disabled distributed training
opt['dist'] = False
trainer.rank = -1
if len(opt['gpu_ids']) == 1:
torch.cuda.set_device(opt['gpu_ids'][0])
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank()
trainer.init(opt, args.launcher)
trainer.do_training()

View File

@ -0,0 +1,10 @@
# StyleGAN Implementations
DLAS supports two different StyleGAN2 implementations:
- [@rosinality implementation](https://github.com/rosinality/stylegan2-pytorch/commits/master)
Designed to reach parity with the nVidia reference implementation in TF1.5
- [@lucidrains implementation](https://github.com/lucidrains/stylegan2-pytorch)
Designed with simplicity and readability in mind.
I prefer the readability of @lucidrains implementation, but you cannot (yet) use pretrained weights
with it. I'm working on that.