James Betker 2021-04-21 18:09:09 -06:00
parent 94e069bced
commit b687ef4cd0
9 changed files with 59 additions and 366 deletions

.gitignore

@@ -18,6 +18,7 @@ data/*
 *.cu
 *.pt
 *.pth
+*.pdf

 # template

.idea/misc.xml

@@ -3,5 +3,5 @@
 <component name="JavaScriptSettings">
   <option name="languageLevel" value="ES6" />
 </component>
-<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (torch_nightly)" project-jdk-type="Python SDK" />
+<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (torch18)" project-jdk-type="Python SDK" />
 </project>

(.idea module .iml file)

@@ -9,7 +9,7 @@
 <excludeFolder url="file://$MODULE_DIR$/results" />
 <excludeFolder url="file://$MODULE_DIR$/tb_logger" />
 </content>
-<orderEntry type="jdk" jdkName="Python 3.8 (torch_nightly)" jdkType="Python SDK" />
+<orderEntry type="jdk" jdkName="Python 3.8 (torch18)" jdkType="Python SDK" />
 <orderEntry type="sourceFolder" forTests="false" />
 </component>
 <component name="PyDocumentationSettings">

data/image_folder_dataset.py

@@ -5,10 +5,14 @@ import random
 import cv2
 import kornia
 import numpy as np
+import pytorch_ssim
 import torch
 import os
+import torchvision
+from torch.utils.data import DataLoader
 from torchvision.transforms import Normalize
+from tqdm import tqdm

 from data import util

 # Builds a dataset created from a simple folder containing a list of training/test/validation images.
@@ -140,7 +144,7 @@ class ImageFolderDataset:
         if self.normalize:
             hq = self.normalize(hq)
-        out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
+        out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item], 'has_alt': False}

         if self.fetch_alt_image:
             # This works by assuming a specific filename structure as would be produced by ffmpeg, ex:
@@ -165,6 +169,7 @@ class ImageFolderDataset:
                 alt_hq = util.read_img(None, next_img, rgb=True)
                 alt_hs = self.resize_hq([alt_hq])
                 alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
+                out_dict['has_alt'] = True
                 if not self.skip_lq:
                     for_lq.append(alt_hs[0])
             except:
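
The truncated comment in the hunk above refers to ffmpeg's zero-padded output numbering. A minimal sketch of how the "next frame" path can be derived under that assumption (this helper is illustrative, not code from the repository):

```python
import os
import re

def next_frame_path(path):
    # Assumes ffmpeg-style zero-padded frame numbering, e.g. 'clip/00000004.jpg'.
    # Returns the path of the following frame, or None if no trailing number exists.
    d, name = os.path.split(path)
    stem, ext = os.path.splitext(name)
    m = re.search(r'(\d+)$', stem)
    if m is None:
        return None
    succ = str(int(m.group(1)) + 1).zfill(len(m.group(1)))
    return os.path.join(d, stem[:m.start()] + succ + ext)

assert next_frame_path('clip/00000004.jpg') == os.path.join('clip', '00000005.jpg')
```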
@@ -200,33 +205,25 @@ class ImageFolderDataset:
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\images\\youtube\\4k_quote_unquote\\images'],
+        'paths': ['E:\\4k6k\\datasets\\ns_images\\256_unsupervised'],
         'weights': [1],
         'target_size': 256,
-        'force_multiple': 32,
+        'force_multiple': 1,
         'scale': 2,
-        'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
-        'random_corruptions': ['noise-5', 'none'],
-        'num_corrupts_per_image': 1,
-        'corrupt_before_downsize': False,
+        'corrupt_before_downsize': True,
         'fetch_alt_image': True,
-        #'labeler': {
-        #    'type': 'patch_labels',
-        #    'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
-        #}
+        'disable_flip': True,
+        'fixed_corruptions': [ 'jpeg-broad' ],
+        'num_corrupts_per_image': 0,
+        'corruption_blur_scale': 0
     }
-    ds = ImageFolderDataset(opt)
+    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=2)
     import os
-    os.makedirs("debug", exist_ok=True)
-    for i in range(0, len(ds)):
-        o = ds[random.randint(0, len(ds)-1)]
-        hq = o['lq']
-        #masked = (o['labels_mask'] * .5 + .5) * hq
-        import torchvision
-        torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_lq.png" % (i,))
-        torchvision.utils.save_image(o['alt_lq'].unsqueeze(0), "debug/%i_lq_alt.png" % (i,))
-        #if len(o['labels'].unique()) > 1:
-        #    randlbl = np.random.choice(o['labels'].unique()[1:])
-        #    moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
-        #    torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
+    output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
+    os.makedirs(output_path, exist_ok=True)
+    for i, d in tqdm(enumerate(ds)):
+        lq = d['lq']
+        torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
+        if i >= 200000:
+            break
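A note on the rewritten __main__ block: with 'target_size': 256 and 'scale': 2 the dataset emits 128x128 'lq' images, and the DataLoader (default batch_size=1) adds a leading batch dimension. Slicing lq[:, :, 16:-16, :] therefore trims 16 rows from the top and bottom, leaving 128 - 2*16 = 96 rows, i.e. 128x96 (4:3) crops; the i+500000 filename offset presumably avoids collisions with previously extracted images. A quick shape check of that slicing, assuming this output layout:

```python
import torch

# Batch of one 3-channel 128x128 'lq' image, as produced with
# target_size=256 and scale=2 plus the DataLoader's batch dimension.
lq = torch.zeros(1, 3, 128, 128)
cropped = lq[:, :, 16:-16, :]
assert cropped.shape == (1, 3, 96, 128)  # 128 - 2*16 = 96 rows -> 4:3 frames
```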


(t-SNE visualization script)

@@ -208,7 +208,7 @@ def run_tsne_instance_level():
     print("Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset.")
     limit = 4000
-    X, files = torch.load('results.pth')
+    X, files = torch.load('../results_instance_resnet.pth')
     zipped = list(zip(X, files))
     shuffle(zipped)
     X, files = zip(*zipped)
@@ -242,7 +242,7 @@ def run_tsne_instance_level():
 # Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
 # spectrum.
 def plot_instance_level_results_as_image_graph():
-    Y, files = torch.load('tsne_output.pth')
+    Y, files = torch.load('../tsne_output.pth')
     fig, ax = pyplot.subplots()
     fig.set_size_inches(200,200,forward=True)
     ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
@@ -250,7 +250,7 @@ def plot_instance_level_results_as_image_graph():
     for b in tqdm(range(Y.shape[0])):
         im = pyplot.imread(files[b])
-        im = OffsetImage(im, zoom=1/8)
+        im = OffsetImage(im, zoom=1/2)
         ab = AnnotationBbox(im, (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False)
         ax.add_artist(ab)
     ax.scatter(Y[:, 0], Y[:, 1])
@@ -277,7 +277,7 @@ def run_tsne_pixel_level():
     '''
     # For resnet-style latent tuples
-    X, files = torch.load('../results.pth')
+    X, files = torch.load('../../results/2021-4-8-imgset-latent-dict.pth')
     zipped = list(zip(X, files))
     shuffle(zipped)
     X, files = zip(*zipped)
@@ -347,8 +347,8 @@ def plot_pixel_level_results_as_image_graph():
 if __name__ == "__main__":
     # For use with instance-level results (e.g. from byol_resnet_playground.py)
     #run_tsne_instance_level()
-    #plot_instance_level_results_as_image_graph()
+    plot_instance_level_results_as_image_graph()

     # For use with pixel-level results (e.g. from byol_uresnet_playground)
     #run_tsne_pixel_level()
-    plot_pixel_level_results_as_image_graph()
+    #plot_pixel_level_results_as_image_graph()
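
Both plotting functions touched above share one pattern: render each image as a thumbnail at its 2-D t-SNE coordinate using matplotlib's offset boxes. A self-contained sketch of that pattern on synthetic data (sizes and filenames are illustrative):

```python
import numpy as np
from matplotlib import pyplot
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

# Synthetic stand-ins for (Y, files): t-SNE coordinates and decoded images.
Y = np.random.randn(10, 2)
images = [np.random.rand(32, 32, 3) for _ in range(10)]

fig, ax = pyplot.subplots()
for b in range(Y.shape[0]):
    # zoom scales each thumbnail; this commit bumps it from 1/8 to 1/2.
    ab = AnnotationBbox(OffsetImage(images[b], zoom=1/2),
                        (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False)
    ax.add_artist(ab)
ax.scatter(Y[:, 0], Y[:, 1])  # scatter also forces the axes to cover every point
pyplot.savefig('tsne_plot.pdf')
```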

(image dataset preparation script)

@@ -19,9 +19,9 @@ def main():
     # compression time. If reading raw images during training, use 0 for faster IO speed.
     opt['dest'] = 'file'
-    opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\faces\\CelebAMask-HQ\\CelebA-HQ-img']
-    opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\faces\\CelebAMask-HQ\\256px'
-    opt['imgsize'] = 256
+    opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats']
+    opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\256_4_by_3'
+    opt['imgsize'] = (256,192)
     opt['bottom_crop'] = 0
     opt['keep_folder'] = False
@@ -66,15 +66,25 @@ class TiledDataset(data.Dataset):
         h, w, c = img.shape

         # Uncomment to filter any image that doesn't meet a threshold size.
-        if min(h,w) < self.opt['imgsize']:
+        imgsz_w, imgsz_h = self.opt['imgsize']
+        if w < imgsz_w or h < imgsz_h:
             print("Skipping due to threshold")
             return None

-        # We must convert the image into a square.
-        dim = min(h, w)
-        # Crop the image so that only the center is left, since this is often the most salient part of the image.
-        img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
-        img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
+        # We must first center-crop the image to the proper aspect ratio.
+        aspect_ratio = imgsz_h / imgsz_w
+        if h < w * aspect_ratio:
+            hdim = h
+            wdim = int(h / aspect_ratio)
+        elif w * aspect_ratio < h:
+            hdim = int(w * aspect_ratio)
+            wdim = w
+        else:
+            hdim = h
+            wdim = w
+        img = img[(h - hdim) // 2:hdim + (h - hdim) // 2, (w - wdim) // 2:wdim + (w - wdim) // 2, :]
+        img = cv2.resize(img, (imgsz_w, imgsz_h), interpolation=cv2.INTER_AREA)

         output_folder = self.opt['save_folder']
         if self.opt['keep_folder']:
             # Attempt to find the folder name one level above opt['input_folder'] and use that.
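
The new branch generalizes the old square center-crop to an arbitrary target aspect ratio: keep whichever dimension is already tight and trim the other symmetrically around the center. A standalone sketch of the same logic with two sanity checks (the helper name is illustrative, not repository code):

```python
import numpy as np

def center_crop_to_aspect(img, target_w, target_h):
    # Center-crop an HxWxC image to the target_h/target_w aspect ratio,
    # mirroring the TiledDataset logic above.
    h, w = img.shape[:2]
    aspect = target_h / target_w
    if h < w * aspect:              # too wide: keep full height, trim width
        hdim, wdim = h, int(h / aspect)
    else:                           # too tall (or exact): keep full width, trim height
        hdim, wdim = int(w * aspect), w
    top, left = (h - hdim) // 2, (w - wdim) // 2
    return img[top:top + hdim, left:left + wdim]

# A 480x640 frame already matches a 256x192 (4:3) target, so nothing is trimmed:
assert center_crop_to_aspect(np.zeros((480, 640, 3)), 256, 192).shape == (480, 640, 3)
# A square 640x640 frame has its height trimmed down to 480:
assert center_crop_to_aspect(np.zeros((640, 640, 3)), 256, 192).shape == (480, 640, 3)
```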

(deleted: standalone training script; the entire 325-line file is removed)

@@ -1,325 +0,0 @@
import os
import math
import argparse
import random
import logging
from tqdm import tqdm

import torch
from data.data_sampler import DistIterSampler
from trainer.eval.evaluator import create_evaluator
from utils import util, options as option
from data import create_dataloader, create_dataset
from trainer.ExtensibleTrainer import ExtensibleTrainer
from time import time


def init_dist(backend, **kwargs):
    # These packages have globals that screw with Windows, so only import them if needed.
    import torch.distributed as dist
    import torch.multiprocessing as mp
    """initialization for distributed training"""
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


class Trainer:
    def init(self, opt, launcher, all_networks={}):
        self._profile = False
        self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'].keys() else True
        self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'].keys() else True

        #### loading resume state if exists
        if opt['path'].get('resume_state', None):
            # distributed resuming: all load into default GPU
            device_id = torch.cuda.current_device()
            resume_state = torch.load(opt['path']['resume_state'],
                                      map_location=lambda storage, loc: storage.cuda(device_id))
            option.check_resume(opt, resume_state['iter'])  # check resume options
        else:
            resume_state = None

        #### mkdir and loggers
        if self.rank <= 0:  # normal training (self.rank -1) OR distributed training (self.rank 0)
            if resume_state is None:
                util.mkdir_and_rename(
                    opt['path']['experiments_root'])  # rename experiment folder if exists
                util.mkdirs(
                    (path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
                     and 'pretrain_model' not in key and 'resume' not in key))

            # config loggers. Before it, the log will not work
            util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                              screen=True, tofile=True)
            self.logger = logging.getLogger('base')
            self.logger.info(option.dict2str(opt))
            # tensorboard logger
            if opt['use_tb_logger'] and 'debug' not in opt['name']:
                self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
                version = float(torch.__version__[0:3])
                if version >= 1.1:  # PyTorch 1.1
                    from torch.utils.tensorboard import SummaryWriter
                else:
                    self.logger.info(
                        'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                    from tensorboardX import SummaryWriter
                self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
        else:
            util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
            self.logger = logging.getLogger('base')

        # convert to NoneDict, which returns None for missing keys
        opt = option.dict_to_nonedict(opt)
        self.opt = opt

        #### wandb init
        if opt['wandb']:
            import wandb
            os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
            wandb.init(project=opt['name'], dir=opt['path']['log'])

        #### random seed
        seed = opt['train']['manual_seed']
        if seed is None:
            seed = random.randint(1, 10000)
        if self.rank <= 0:
            self.logger.info('Random seed: {}'.format(seed))
        seed += self.rank  # Different multiprocessing instances should behave differently.
        util.set_random_seed(seed)

        torch.backends.cudnn.benchmark = True
        # torch.backends.cudnn.deterministic = True
        # torch.autograd.set_detect_anomaly(True)

        # Save the compiled opt dict to the global loaded_options variable.
        util.loaded_options = opt

        #### create train and val dataloader
        dataset_ratio = 1  # enlarge the size of each epoch
        for phase, dataset_opt in opt['datasets'].items():
            if phase == 'train':
                self.train_set = create_dataset(dataset_opt)
                train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
                total_iters = int(opt['train']['niter'])
                self.total_epochs = int(math.ceil(total_iters / train_size))
                if opt['dist']:
                    self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
                    self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                else:
                    self.train_sampler = None
                self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler)
                if self.rank <= 0:
                    self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                        len(self.train_set), train_size))
                    self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                        self.total_epochs, total_iters))
            elif phase == 'val':
                self.val_set = create_dataset(dataset_opt)
                self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None)
                if self.rank <= 0:
                    self.logger.info('Number of val images in [{:s}]: {:d}'.format(
                        dataset_opt['name'], len(self.val_set)))
            else:
                raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
        assert self.train_loader is not None

        #### create model
        self.model = ExtensibleTrainer(opt, cached_networks=all_networks)

        ### Evaluators
        self.evaluators = []
        if 'evaluators' in opt['eval'].keys():
            for ev_key, ev_opt in opt['eval']['evaluators'].items():
                self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
                                                        ev_opt, self.model.env))

        #### resume training
        if resume_state:
            self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
                resume_state['epoch'], resume_state['iter']))
            self.start_epoch = resume_state['epoch']
            self.current_step = resume_state['iter']
            self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys())  # handle optimizers and schedulers
        else:
            self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
            self.start_epoch = 0
        if 'force_start_step' in opt.keys():
            self.current_step = opt['force_start_step']
        opt['current_step'] = self.current_step

    def do_step(self, train_data):
        if self._profile:
            print("Data fetch: %f" % (time() - _t))
            _t = time()

        opt = self.opt
        self.current_step += 1
        #### update learning rate
        self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])

        #### training
        if self._profile:
            print("Update LR: %f" % (time() - _t))
            _t = time()
        self.model.feed_data(train_data, self.current_step)
        self.model.optimize_parameters(self.current_step)
        if self._profile:
            print("Model feed + step: %f" % (time() - _t))
            _t = time()

        #### log
        if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
            logs = self.model.get_current_log(self.current_step)
            message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
            for v in self.model.get_current_learning_rate():
                message += '{:.3e},'.format(v)
            message += ')] '
            for k, v in logs.items():
                if 'histogram' in k:
                    self.tb_logger.add_histogram(k, v, self.current_step)
                elif isinstance(v, dict):
                    self.tb_logger.add_scalars(k, v, self.current_step)
                else:
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        self.tb_logger.add_scalar(k, v, self.current_step)
            if opt['wandb']:
                import wandb
                wandb.log(logs)
            self.logger.info(message)

        #### save models and training states
        if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
            if self.rank <= 0:
                self.logger.info('Saving models and training states.')
                self.model.save(self.current_step)
                self.model.save_training_state(self.epoch, self.current_step)
                if 'alt_path' in opt['path'].keys():
                    import shutil
                    print("Synchronizing tb_logger to alt_path..")
                    alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
                    shutil.rmtree(alt_tblogger, ignore_errors=True)
                    shutil.copytree(self.tb_logger_path, alt_tblogger)

        #### validation
        if opt['datasets'].get('val', None) and self.current_step % opt['train']['val_freq'] == 0:
            if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan',
                                'extensibletrainer'] and self.rank <= 0:  # image restoration validation
                avg_psnr = 0.
                avg_fea_loss = 0.
                idx = 0
                val_tqdm = tqdm(self.val_loader)
                for val_data in val_tqdm:
                    idx += 1
                    for b in range(len(val_data['HQ_path'])):
                        img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
                        img_dir = os.path.join(opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        self.model.feed_data(val_data, self.current_step)
                        self.model.test()

                        visuals = self.model.get_current_visuals()
                        if visuals is None:
                            continue

                        sr_img = util.tensor2img(visuals['rlt'][b])  # uint8

                        # calculate PSNR
                        if self.val_compute_psnr:
                            gt_img = util.tensor2img(visuals['hq'][b])  # uint8
                            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                            avg_psnr += util.calculate_psnr(sr_img, gt_img)

                        # calculate fea loss
                        if self.val_compute_fea:
                            avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])

                        # Save SR images for reference
                        img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
                        save_img_path = os.path.join(img_dir, img_base_name)
                        util.save_img(sr_img, save_img_path)

                avg_psnr = avg_psnr / idx
                avg_fea_loss = avg_fea_loss / idx

                # log
                self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))

                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
                    self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
                    self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)

        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0:
            eval_dict = {}
            for eval in self.evaluators:
                eval_dict.update(eval.perform_eval())
            if self.rank <= 0:
                print("Evaluator results: ", eval_dict)
                for ek, ev in eval_dict.items():
                    self.tb_logger.add_scalar(ek, ev, self.current_step)

    def do_training(self):
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if opt['dist']:
                self.train_sampler.set_epoch(epoch)
            tq_ldr = tqdm(self.train_loader)

            _t = time()
            for train_data in tq_ldr:
                self.do_step(train_data)

    def create_training_generator(self, index):
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if self.opt['dist']:
                self.train_sampler.set_epoch(epoch)
            tq_ldr = tqdm(self.train_loader, position=index)

            _t = time()
            for train_data in tq_ldr:
                yield self.model
                self.do_step(train_data)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixpro_3.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
    if args.launcher != 'none':
        # export CUDA_VISIBLE_DEVICES for running in distributed mode.
        if 'gpu_ids' in opt.keys():
            gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
            print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
    trainer = Trainer()

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        trainer.rank = -1
        if len(opt['gpu_ids']) == 1:
            torch.cuda.set_device(opt['gpu_ids'][0])
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist('nccl')
        trainer.world_size = torch.distributed.get_world_size()
        trainer.rank = torch.distributed.get_rank()

    trainer.init(opt, args.launcher)
    trainer.do_training()

(new file: StyleGAN implementations README)

@@ -0,0 +1,10 @@
# StyleGAN Implementations

DLAS supports two different StyleGAN2 implementations:

- [@rosinality implementation](https://github.com/rosinality/stylegan2-pytorch/commits/master)
  Designed to reach parity with the nVidia reference implementation in TF 1.5.
- [@lucidrains implementation](https://github.com/lucidrains/stylegan2-pytorch)
  Designed with simplicity and readability in mind.

I prefer the readability of the @lucidrains implementation, but you cannot (yet) use pretrained weights with it. I'm working on that.