Misc
This commit is contained in:
parent
94e069bced
commit
b687ef4cd0
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -18,6 +18,7 @@ data/*
|
||||||
*.cu
|
*.cu
|
||||||
*.pt
|
*.pt
|
||||||
*.pth
|
*.pth
|
||||||
|
*.pdf
|
||||||
|
|
||||||
# template
|
# template
|
||||||
|
|
||||||
|
|
|
@ -3,5 +3,5 @@
|
||||||
<component name="JavaScriptSettings">
|
<component name="JavaScriptSettings">
|
||||||
<option name="languageLevel" value="ES6" />
|
<option name="languageLevel" value="ES6" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (torch_nightly)" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (torch18)" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
|
@ -9,7 +9,7 @@
|
||||||
<excludeFolder url="file://$MODULE_DIR$/results" />
|
<excludeFolder url="file://$MODULE_DIR$/results" />
|
||||||
<excludeFolder url="file://$MODULE_DIR$/tb_logger" />
|
<excludeFolder url="file://$MODULE_DIR$/tb_logger" />
|
||||||
</content>
|
</content>
|
||||||
<orderEntry type="jdk" jdkName="Python 3.8 (torch_nightly)" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="Python 3.8 (torch18)" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
<component name="PyDocumentationSettings">
|
<component name="PyDocumentationSettings">
|
||||||
|
|
|
@ -5,10 +5,14 @@ import random
|
||||||
import cv2
|
import cv2
|
||||||
import kornia
|
import kornia
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytorch_ssim
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torchvision
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
from torchvision.transforms import Normalize
|
from torchvision.transforms import Normalize
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from data import util
|
from data import util
|
||||||
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
|
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
|
||||||
|
@ -140,7 +144,7 @@ class ImageFolderDataset:
|
||||||
if self.normalize:
|
if self.normalize:
|
||||||
hq = self.normalize(hq)
|
hq = self.normalize(hq)
|
||||||
|
|
||||||
out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
|
out_dict = {'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item], 'has_alt': False}
|
||||||
|
|
||||||
if self.fetch_alt_image:
|
if self.fetch_alt_image:
|
||||||
# This works by assuming a specific filename structure as would produced by ffmpeg. ex:
|
# This works by assuming a specific filename structure as would produced by ffmpeg. ex:
|
||||||
|
@ -165,6 +169,7 @@ class ImageFolderDataset:
|
||||||
alt_hq = util.read_img(None, next_img, rgb=True)
|
alt_hq = util.read_img(None, next_img, rgb=True)
|
||||||
alt_hs = self.resize_hq([alt_hq])
|
alt_hs = self.resize_hq([alt_hq])
|
||||||
alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
|
alt_hq = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_hs[0], (2, 0, 1)))).float()
|
||||||
|
out_dict['has_alt'] = True
|
||||||
if not self.skip_lq:
|
if not self.skip_lq:
|
||||||
for_lq.append(alt_hs[0])
|
for_lq.append(alt_hs[0])
|
||||||
except:
|
except:
|
||||||
|
@ -200,33 +205,25 @@ class ImageFolderDataset:
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
opt = {
|
opt = {
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
'paths': ['F:\\4k6k\\datasets\\images\\youtube\\4k_quote_unquote\\images'],
|
'paths': ['E:\\4k6k\\datasets\\ns_images\\256_unsupervised'],
|
||||||
'weights': [1],
|
'weights': [1],
|
||||||
'target_size': 256,
|
'target_size': 256,
|
||||||
'force_multiple': 32,
|
'force_multiple': 1,
|
||||||
'scale': 2,
|
'scale': 2,
|
||||||
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
'corrupt_before_downsize': True,
|
||||||
'random_corruptions': ['noise-5', 'none'],
|
|
||||||
'num_corrupts_per_image': 1,
|
|
||||||
'corrupt_before_downsize': False,
|
|
||||||
'fetch_alt_image': True,
|
'fetch_alt_image': True,
|
||||||
#'labeler': {
|
'disable_flip': True,
|
||||||
# 'type': 'patch_labels',
|
'fixed_corruptions': [ 'jpeg-broad' ],
|
||||||
# 'label_file': 'F:\\4k6k\\datasets\\ns_images\\512_unsupervised\\categories_new.json'
|
'num_corrupts_per_image': 0,
|
||||||
#}
|
'corruption_blur_scale': 0
|
||||||
}
|
}
|
||||||
|
|
||||||
ds = ImageFolderDataset(opt)
|
ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=2)
|
||||||
import os
|
import os
|
||||||
os.makedirs("debug", exist_ok=True)
|
output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
|
||||||
for i in range(0, len(ds)):
|
os.makedirs(output_path, exist_ok=True)
|
||||||
o = ds[random.randint(0, len(ds)-1)]
|
for i, d in tqdm(enumerate(ds)):
|
||||||
hq = o['lq']
|
lq = d['lq']
|
||||||
#masked = (o['labels_mask'] * .5 + .5) * hq
|
torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
|
||||||
import torchvision
|
if i >= 200000:
|
||||||
torchvision.utils.save_image(hq.unsqueeze(0), "debug/%i_lq.png" % (i,))
|
break
|
||||||
torchvision.utils.save_image(o['alt_lq'].unsqueeze(0), "debug/%i_lq_alt.png" % (i,))
|
|
||||||
#if len(o['labels'].unique()) > 1:
|
|
||||||
# randlbl = np.random.choice(o['labels'].unique()[1:])
|
|
||||||
# moremask = hq * ((1*(o['labels'] == randlbl))*.5+.5)
|
|
||||||
# torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s.png" % (i, o['label_strings'][randlbl]))
|
|
0
codes/models/global_convs/gc_resnet.py
Normal file
0
codes/models/global_convs/gc_resnet.py
Normal file
|
@ -208,7 +208,7 @@ def run_tsne_instance_level():
|
||||||
print("Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset.")
|
print("Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset.")
|
||||||
|
|
||||||
limit = 4000
|
limit = 4000
|
||||||
X, files = torch.load('results.pth')
|
X, files = torch.load('../results_instance_resnet.pth')
|
||||||
zipped = list(zip(X, files))
|
zipped = list(zip(X, files))
|
||||||
shuffle(zipped)
|
shuffle(zipped)
|
||||||
X, files = zip(*zipped)
|
X, files = zip(*zipped)
|
||||||
|
@ -242,7 +242,7 @@ def run_tsne_instance_level():
|
||||||
# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
|
# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
|
||||||
# spectrum.
|
# spectrum.
|
||||||
def plot_instance_level_results_as_image_graph():
|
def plot_instance_level_results_as_image_graph():
|
||||||
Y, files = torch.load('tsne_output.pth')
|
Y, files = torch.load('../tsne_output.pth')
|
||||||
fig, ax = pyplot.subplots()
|
fig, ax = pyplot.subplots()
|
||||||
fig.set_size_inches(200,200,forward=True)
|
fig.set_size_inches(200,200,forward=True)
|
||||||
ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
|
ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
|
||||||
|
@ -250,7 +250,7 @@ def plot_instance_level_results_as_image_graph():
|
||||||
|
|
||||||
for b in tqdm(range(Y.shape[0])):
|
for b in tqdm(range(Y.shape[0])):
|
||||||
im = pyplot.imread(files[b])
|
im = pyplot.imread(files[b])
|
||||||
im = OffsetImage(im, zoom=1/8)
|
im = OffsetImage(im, zoom=1/2)
|
||||||
ab = AnnotationBbox(im, (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False)
|
ab = AnnotationBbox(im, (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False)
|
||||||
ax.add_artist(ab)
|
ax.add_artist(ab)
|
||||||
ax.scatter(Y[:, 0], Y[:, 1])
|
ax.scatter(Y[:, 0], Y[:, 1])
|
||||||
|
@ -277,7 +277,7 @@ def run_tsne_pixel_level():
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# For resnet-style latent tuples
|
# For resnet-style latent tuples
|
||||||
X, files = torch.load('../results.pth')
|
X, files = torch.load('../../results/2021-4-8-imgset-latent-dict.pth')
|
||||||
zipped = list(zip(X, files))
|
zipped = list(zip(X, files))
|
||||||
shuffle(zipped)
|
shuffle(zipped)
|
||||||
X, files = zip(*zipped)
|
X, files = zip(*zipped)
|
||||||
|
@ -347,8 +347,8 @@ def plot_pixel_level_results_as_image_graph():
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# For use with instance-level results (e.g. from byol_resnet_playground.py)
|
# For use with instance-level results (e.g. from byol_resnet_playground.py)
|
||||||
#run_tsne_instance_level()
|
#run_tsne_instance_level()
|
||||||
#plot_instance_level_results_as_image_graph()
|
plot_instance_level_results_as_image_graph()
|
||||||
|
|
||||||
# For use with pixel-level results (e.g. from byol_uresnet_playground)
|
# For use with pixel-level results (e.g. from byol_uresnet_playground)
|
||||||
#run_tsne_pixel_level()
|
#run_tsne_pixel_level()
|
||||||
plot_pixel_level_results_as_image_graph()
|
#plot_pixel_level_results_as_image_graph()
|
|
@ -19,9 +19,9 @@ def main():
|
||||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||||
|
|
||||||
opt['dest'] = 'file'
|
opt['dest'] = 'file'
|
||||||
opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\faces\\CelebAMask-HQ\\CelebA-HQ-img']
|
opt['input_folder'] = ['E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats']
|
||||||
opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\faces\\CelebAMask-HQ\\256px'
|
opt['save_folder'] = 'E:\\4k6k\\datasets\\images\\lsun\\lsun\\cats\\256_4_by_3'
|
||||||
opt['imgsize'] = 256
|
opt['imgsize'] = (256,192)
|
||||||
opt['bottom_crop'] = 0
|
opt['bottom_crop'] = 0
|
||||||
opt['keep_folder'] = False
|
opt['keep_folder'] = False
|
||||||
|
|
||||||
|
@ -66,15 +66,25 @@ class TiledDataset(data.Dataset):
|
||||||
|
|
||||||
h, w, c = img.shape
|
h, w, c = img.shape
|
||||||
# Uncomment to filter any image that doesnt meet a threshold size.
|
# Uncomment to filter any image that doesnt meet a threshold size.
|
||||||
if min(h,w) < self.opt['imgsize']:
|
imgsz_w, imgsz_h = self.opt['imgsize']
|
||||||
|
if w < imgsz_w or h < imgsz_h:
|
||||||
print("Skipping due to threshold")
|
print("Skipping due to threshold")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# We must convert the image into a square.
|
# We must first center-crop the image to the proper aspect ratio
|
||||||
dim = min(h, w)
|
aspect_ratio = imgsz_h / imgsz_w
|
||||||
# Crop the image so that only the center is left, since this is often the most salient part of the image.
|
if h < w * aspect_ratio:
|
||||||
img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
hdim = h
|
||||||
img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
|
wdim = int(h / aspect_ratio)
|
||||||
|
elif w * aspect_ratio < h:
|
||||||
|
hdim = int(w * aspect_ratio)
|
||||||
|
wdim = w
|
||||||
|
else:
|
||||||
|
hdim = h
|
||||||
|
wdim = w
|
||||||
|
img = img[(h - hdim) // 2:hdim + (h - hdim) // 2, (w - wdim) // 2:wdim + (w - wdim) // 2, :]
|
||||||
|
|
||||||
|
img = cv2.resize(img, (imgsz_w, imgsz_h), interpolation=cv2.INTER_AREA)
|
||||||
output_folder = self.opt['save_folder']
|
output_folder = self.opt['save_folder']
|
||||||
if self.opt['keep_folder']:
|
if self.opt['keep_folder']:
|
||||||
# Attempt to find the folder name one level above opt['input_folder'] and use that.
|
# Attempt to find the folder name one level above opt['input_folder'] and use that.
|
||||||
|
|
325
codes/train2.py
325
codes/train2.py
|
@ -1,325 +0,0 @@
|
||||||
import os
|
|
||||||
import math
|
|
||||||
import argparse
|
|
||||||
import random
|
|
||||||
import logging
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from data.data_sampler import DistIterSampler
|
|
||||||
from trainer.eval.evaluator import create_evaluator
|
|
||||||
|
|
||||||
from utils import util, options as option
|
|
||||||
from data import create_dataloader, create_dataset
|
|
||||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
|
||||||
from time import time
|
|
||||||
|
|
||||||
def init_dist(backend, **kwargs):
|
|
||||||
# These packages have globals that screw with Windows, so only import them if needed.
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
|
|
||||||
"""initialization for distributed training"""
|
|
||||||
if mp.get_start_method(allow_none=True) != 'spawn':
|
|
||||||
mp.set_start_method('spawn')
|
|
||||||
rank = int(os.environ['RANK'])
|
|
||||||
num_gpus = torch.cuda.device_count()
|
|
||||||
torch.cuda.set_device(rank % num_gpus)
|
|
||||||
dist.init_process_group(backend=backend, **kwargs)
|
|
||||||
|
|
||||||
class Trainer:
|
|
||||||
|
|
||||||
def init(self, opt, launcher, all_networks={}):
|
|
||||||
self._profile = False
|
|
||||||
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'].keys() else True
|
|
||||||
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'].keys() else True
|
|
||||||
|
|
||||||
#### loading resume state if exists
|
|
||||||
if opt['path'].get('resume_state', None):
|
|
||||||
# distributed resuming: all load into default GPU
|
|
||||||
device_id = torch.cuda.current_device()
|
|
||||||
resume_state = torch.load(opt['path']['resume_state'],
|
|
||||||
map_location=lambda storage, loc: storage.cuda(device_id))
|
|
||||||
option.check_resume(opt, resume_state['iter']) # check resume options
|
|
||||||
else:
|
|
||||||
resume_state = None
|
|
||||||
|
|
||||||
#### mkdir and loggers
|
|
||||||
if self.rank <= 0: # normal training (self.rank -1) OR distributed training (self.rank 0)
|
|
||||||
if resume_state is None:
|
|
||||||
util.mkdir_and_rename(
|
|
||||||
opt['path']['experiments_root']) # rename experiment folder if exists
|
|
||||||
util.mkdirs(
|
|
||||||
(path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
|
|
||||||
and 'pretrain_model' not in key and 'resume' not in key))
|
|
||||||
|
|
||||||
# config loggers. Before it, the log will not work
|
|
||||||
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
|
||||||
screen=True, tofile=True)
|
|
||||||
self.logger = logging.getLogger('base')
|
|
||||||
self.logger.info(option.dict2str(opt))
|
|
||||||
# tensorboard logger
|
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
|
||||||
self.tb_logger_path = os.path.join(opt['path']['experiments_root'], 'tb_logger')
|
|
||||||
version = float(torch.__version__[0:3])
|
|
||||||
if version >= 1.1: # PyTorch 1.1
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
else:
|
|
||||||
self.self.logger.info(
|
|
||||||
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
self.tb_logger = SummaryWriter(log_dir=self.tb_logger_path)
|
|
||||||
else:
|
|
||||||
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
|
||||||
self.logger = logging.getLogger('base')
|
|
||||||
|
|
||||||
# convert to NoneDict, which returns None for missing keys
|
|
||||||
opt = option.dict_to_nonedict(opt)
|
|
||||||
self.opt = opt
|
|
||||||
|
|
||||||
#### wandb init
|
|
||||||
if opt['wandb']:
|
|
||||||
import wandb
|
|
||||||
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
|
|
||||||
wandb.init(project=opt['name'], dir=opt['path']['log'])
|
|
||||||
|
|
||||||
#### random seed
|
|
||||||
seed = opt['train']['manual_seed']
|
|
||||||
if seed is None:
|
|
||||||
seed = random.randint(1, 10000)
|
|
||||||
if self.rank <= 0:
|
|
||||||
self.logger.info('Random seed: {}'.format(seed))
|
|
||||||
seed += self.rank # Different multiprocessing instances should behave differently.
|
|
||||||
util.set_random_seed(seed)
|
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
# torch.backends.cudnn.deterministic = True
|
|
||||||
# torch.autograd.set_detect_anomaly(True)
|
|
||||||
|
|
||||||
# Save the compiled opt dict to the global loaded_options variable.
|
|
||||||
util.loaded_options = opt
|
|
||||||
|
|
||||||
#### create train and val dataloader
|
|
||||||
dataset_ratio = 1 # enlarge the size of each epoch
|
|
||||||
for phase, dataset_opt in opt['datasets'].items():
|
|
||||||
if phase == 'train':
|
|
||||||
self.train_set = create_dataset(dataset_opt)
|
|
||||||
train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
|
|
||||||
total_iters = int(opt['train']['niter'])
|
|
||||||
self.total_epochs = int(math.ceil(total_iters / train_size))
|
|
||||||
if opt['dist']:
|
|
||||||
self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
|
|
||||||
self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
|
||||||
else:
|
|
||||||
self.train_sampler = None
|
|
||||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler)
|
|
||||||
if self.rank <= 0:
|
|
||||||
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
|
||||||
len(self.train_set), train_size))
|
|
||||||
self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
|
||||||
self.total_epochs, total_iters))
|
|
||||||
elif phase == 'val':
|
|
||||||
self.val_set = create_dataset(dataset_opt)
|
|
||||||
self.val_loader = create_dataloader(self.val_set, dataset_opt, opt, None)
|
|
||||||
if self.rank <= 0:
|
|
||||||
self.logger.info('Number of val images in [{:s}]: {:d}'.format(
|
|
||||||
dataset_opt['name'], len(self.val_set)))
|
|
||||||
else:
|
|
||||||
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
|
||||||
assert self.train_loader is not None
|
|
||||||
|
|
||||||
#### create model
|
|
||||||
self.model = ExtensibleTrainer(opt, cached_networks=all_networks)
|
|
||||||
|
|
||||||
### Evaluators
|
|
||||||
self.evaluators = []
|
|
||||||
if 'evaluators' in opt['eval'].keys():
|
|
||||||
for ev_key, ev_opt in opt['eval']['evaluators'].items():
|
|
||||||
self.evaluators.append(create_evaluator(self.model.networks[ev_opt['for']],
|
|
||||||
ev_opt, self.model.env))
|
|
||||||
|
|
||||||
#### resume training
|
|
||||||
if resume_state:
|
|
||||||
self.logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
|
||||||
resume_state['epoch'], resume_state['iter']))
|
|
||||||
|
|
||||||
self.start_epoch = resume_state['epoch']
|
|
||||||
self.current_step = resume_state['iter']
|
|
||||||
self.model.resume_training(resume_state, 'amp_opt_level' in opt.keys()) # handle optimizers and schedulers
|
|
||||||
else:
|
|
||||||
self.current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
|
||||||
self.start_epoch = 0
|
|
||||||
if 'force_start_step' in opt.keys():
|
|
||||||
self.current_step = opt['force_start_step']
|
|
||||||
opt['current_step'] = self.current_step
|
|
||||||
|
|
||||||
def do_step(self, train_data):
|
|
||||||
if self._profile:
|
|
||||||
print("Data fetch: %f" % (time() - _t))
|
|
||||||
_t = time()
|
|
||||||
|
|
||||||
opt = self.opt
|
|
||||||
self.current_step += 1
|
|
||||||
#### update learning rate
|
|
||||||
self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
|
|
||||||
|
|
||||||
#### training
|
|
||||||
if self._profile:
|
|
||||||
print("Update LR: %f" % (time() - _t))
|
|
||||||
_t = time()
|
|
||||||
self.model.feed_data(train_data, self.current_step)
|
|
||||||
self.model.optimize_parameters(self.current_step)
|
|
||||||
if self._profile:
|
|
||||||
print("Model feed + step: %f" % (time() - _t))
|
|
||||||
_t = time()
|
|
||||||
|
|
||||||
#### log
|
|
||||||
if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
|
|
||||||
logs = self.model.get_current_log(self.current_step)
|
|
||||||
message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
|
|
||||||
for v in self.model.get_current_learning_rate():
|
|
||||||
message += '{:.3e},'.format(v)
|
|
||||||
message += ')] '
|
|
||||||
for k, v in logs.items():
|
|
||||||
if 'histogram' in k:
|
|
||||||
self.tb_logger.add_histogram(k, v, self.current_step)
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
self.tb_logger.add_scalars(k, v, self.current_step)
|
|
||||||
else:
|
|
||||||
message += '{:s}: {:.4e} '.format(k, v)
|
|
||||||
# tensorboard logger
|
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
|
||||||
self.tb_logger.add_scalar(k, v, self.current_step)
|
|
||||||
if opt['wandb']:
|
|
||||||
import wandb
|
|
||||||
wandb.log(logs)
|
|
||||||
self.logger.info(message)
|
|
||||||
|
|
||||||
#### save models and training states
|
|
||||||
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
|
||||||
if self.rank <= 0:
|
|
||||||
self.logger.info('Saving models and training states.')
|
|
||||||
self.model.save(self.current_step)
|
|
||||||
self.model.save_training_state(self.epoch, self.current_step)
|
|
||||||
if 'alt_path' in opt['path'].keys():
|
|
||||||
import shutil
|
|
||||||
print("Synchronizing tb_logger to alt_path..")
|
|
||||||
alt_tblogger = os.path.join(opt['path']['alt_path'], "tb_logger")
|
|
||||||
shutil.rmtree(alt_tblogger, ignore_errors=True)
|
|
||||||
shutil.copytree(self.tb_logger_path, alt_tblogger)
|
|
||||||
|
|
||||||
#### validation
|
|
||||||
if opt['datasets'].get('val', None) and self.current_step % opt['train']['val_freq'] == 0:
|
|
||||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan',
|
|
||||||
'extensibletrainer'] and self.rank <= 0: # image restoration validation
|
|
||||||
avg_psnr = 0.
|
|
||||||
avg_fea_loss = 0.
|
|
||||||
idx = 0
|
|
||||||
val_tqdm = tqdm(self.val_loader)
|
|
||||||
for val_data in val_tqdm:
|
|
||||||
idx += 1
|
|
||||||
for b in range(len(val_data['HQ_path'])):
|
|
||||||
img_name = os.path.splitext(os.path.basename(val_data['HQ_path'][b]))[0]
|
|
||||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
|
||||||
|
|
||||||
util.mkdir(img_dir)
|
|
||||||
|
|
||||||
self.model.feed_data(val_data, self.current_step)
|
|
||||||
self.model.test()
|
|
||||||
|
|
||||||
visuals = self.model.get_current_visuals()
|
|
||||||
if visuals is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
|
||||||
# calculate PSNR
|
|
||||||
if self.val_compute_psnr:
|
|
||||||
gt_img = util.tensor2img(visuals['hq'][b]) # uint8
|
|
||||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
|
||||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
|
||||||
|
|
||||||
# calculate fea loss
|
|
||||||
if self.val_compute_fea:
|
|
||||||
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
|
|
||||||
|
|
||||||
# Save SR images for reference
|
|
||||||
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
|
|
||||||
save_img_path = os.path.join(img_dir, img_base_name)
|
|
||||||
util.save_img(sr_img, save_img_path)
|
|
||||||
|
|
||||||
avg_psnr = avg_psnr / idx
|
|
||||||
avg_fea_loss = avg_fea_loss / idx
|
|
||||||
|
|
||||||
# log
|
|
||||||
self.logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
|
||||||
|
|
||||||
# tensorboard logger
|
|
||||||
if opt['use_tb_logger'] and 'debug' not in opt['name'] and self.rank <= 0:
|
|
||||||
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
|
|
||||||
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
|
|
||||||
|
|
||||||
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0:
|
|
||||||
eval_dict = {}
|
|
||||||
for eval in self.evaluators:
|
|
||||||
eval_dict.update(eval.perform_eval())
|
|
||||||
if self.rank <= 0:
|
|
||||||
print("Evaluator results: ", eval_dict)
|
|
||||||
for ek, ev in eval_dict.items():
|
|
||||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
|
||||||
|
|
||||||
def do_training(self):
|
|
||||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
|
||||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
|
||||||
self.epoch = epoch
|
|
||||||
if opt['dist']:
|
|
||||||
self.train_sampler.set_epoch(epoch)
|
|
||||||
tq_ldr = tqdm(self.train_loader)
|
|
||||||
|
|
||||||
_t = time()
|
|
||||||
for train_data in tq_ldr:
|
|
||||||
self.do_step(train_data)
|
|
||||||
|
|
||||||
def create_training_generator(self, index):
|
|
||||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
|
||||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
|
||||||
self.epoch = epoch
|
|
||||||
if self.opt['dist']:
|
|
||||||
self.train_sampler.set_epoch(epoch)
|
|
||||||
tq_ldr = tqdm(self.train_loader, position=index)
|
|
||||||
|
|
||||||
_t = time()
|
|
||||||
for train_data in tq_ldr:
|
|
||||||
yield self.model
|
|
||||||
self.do_step(train_data)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixpro_3.yml')
|
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
|
||||||
args = parser.parse_args()
|
|
||||||
opt = option.parse(args.opt, is_train=True)
|
|
||||||
if args.launcher != 'none':
|
|
||||||
# export CUDA_VISIBLE_DEVICES for running in distributed mode.
|
|
||||||
if 'gpu_ids' in opt.keys():
|
|
||||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
|
||||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
|
||||||
trainer = Trainer()
|
|
||||||
|
|
||||||
#### distributed training settings
|
|
||||||
if args.launcher == 'none': # disabled distributed training
|
|
||||||
opt['dist'] = False
|
|
||||||
trainer.rank = -1
|
|
||||||
if len(opt['gpu_ids']) == 1:
|
|
||||||
torch.cuda.set_device(opt['gpu_ids'][0])
|
|
||||||
print('Disabled distributed training.')
|
|
||||||
else:
|
|
||||||
opt['dist'] = True
|
|
||||||
init_dist('nccl')
|
|
||||||
trainer.world_size = torch.distributed.get_world_size()
|
|
||||||
trainer.rank = torch.distributed.get_rank()
|
|
||||||
|
|
||||||
trainer.init(opt, args.launcher)
|
|
||||||
trainer.do_training()
|
|
10
recipes/stylegan/README.md
Normal file
10
recipes/stylegan/README.md
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
# StyleGAN Implementations
|
||||||
|
DLAS supports two different StyleGAN2 implementations:
|
||||||
|
|
||||||
|
- [@rosinality implementation](https://github.com/rosinality/stylegan2-pytorch/commits/master)
|
||||||
|
Designed to reach parity with the nVidia reference implementation in TF1.5
|
||||||
|
- [@lucidrains implementation](https://github.com/lucidrains/stylegan2-pytorch)
|
||||||
|
Designed with simplicity and readability in mind.
|
||||||
|
|
||||||
|
I prefer the readability of @lucidrains implementation, but you cannot (yet) use pretrained weights
|
||||||
|
with it. I'm working on that.
|
Loading…
Reference in New Issue
Block a user