Merge remote-tracking branch 'origin/gan_lab' into gan_lab

James Betker 2020-12-07 12:51:04 -07:00
commit bca59ed98a
27 changed files with 277 additions and 64 deletions

View File

@@ -42,6 +42,11 @@ TBC..
## User Guide
TBC
+### Development Environment
+If you aren't already using [Pycharm](https://www.jetbrains.com/pycharm/) - now is the time to try it out. This project was built in Pycharm and comes with
+an IDEA project for you to get started with. I've done all of my development on this repo in this IDE and lean heavily
+on its incredible debugger. It's free. Try it out. You won't be sorry.
### Dataset Preparation
DLAS comes with some Dataset instances that I have created for my own use. Unless you want to use one of the recipes above, you'll need to provide your own. Here is how to add your own Dataset:
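Editor's note: a custom Dataset only needs to hand the trainer a dict with the keys it consumes. A minimal sketch, assuming the 'lq'/'hq' key convention this commit introduces (the class name, option keys and the LQ synthesis below are illustrative, not part of the repo):

import os
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF

class MyPairedDataset(Dataset):
    def __init__(self, opt):
        self.dir = opt['path']              # assumed option key: a folder of HQ images
        self.scale = opt.get('scale', 2)    # downscale factor used to synthesize LQ
        self.files = sorted(os.path.join(self.dir, f) for f in os.listdir(self.dir)
                            if f.lower().endswith(('.png', '.jpg', '.jpeg')))

    def __getitem__(self, i):
        hq = TF.to_tensor(Image.open(self.files[i]).convert('RGB'))
        _, h, w = hq.shape
        lq = TF.resize(hq, [h // self.scale, w // self.scale])
        # Dict keys follow the renaming visible throughout this commit.
        return {'lq': lq, 'hq': hq, 'LQ_path': self.files[i], 'GT_path': self.files[i]}

    def __len__(self):
        return len(self.files)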

View File

@@ -20,12 +20,15 @@ This directory contains several reference datasets which I have used in building
1. MultiframeDataset - Similar to SingleImageDataset, but infers a temporal relationship between images based on their
filenames: the last 12 characters before the file extension are assumed to be a frame counter. Images from this
dataset are grouped together with a temporal dimension for working with video data.
+1. ImageFolderDataset - Reads raw images from a folder and feeds them into the model. Capable of performing corruptions
+on those images like the above.
1. MultiscaleDataset - Reads full images from a directory and builds a tree of images constructed by cropping squares
from the source image and resizing them to the target size recursively until the native resolution is hit. Each
recursive step decreases the crop size by a factor of 2.
+1. TorchDataset - A wrapper for miscellaneous pytorch datasets (e.g. MNIST, CIFAR, etc) which extracts the images
+and reformats them in a way that the DLAS trainer understands.
1. FullImageDataset - An image patch dataset where the patches are dynamically extracted from full-size images. I have
-generally stopped using this for performance reasons in favor of SingleImageDataset but it is useful for validation
-and test so I keep it around.
+generally stopped using this for performance reasons and it should be considered deprecated.
## Information about the "chunked" format
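Editor's note: the option dict such a dataset is built from can be seen in the new analysis script later in this commit. A sketch of wiring the ImageFolderDataset described above into a DataLoader (paths and values are placeholders, not repo defaults):

from torch.utils.data import DataLoader
from data.image_folder_dataset import ImageFolderDataset

dataset_opt = {
    'name': 'my_images',
    'paths': ['/data/my_image_folder'],  # placeholder: one or more folders of raw images
    'weights': [1],                      # relative sampling weight per folder
    'target_size': 512,
    'force_multiple': 32,
    'scale': 1,
}
ds = ImageFolderDataset(dataset_opt)
loader = DataLoader(ds, batch_size=8, shuffle=True, num_workers=1)
batch = next(iter(loader))               # per this commit: batch['lq'], batch['hq'], batch['HQ_path']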

View File

@@ -323,7 +323,7 @@ class FullImageDataset(data.Dataset):
gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
-d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
+d = {'lq': img_LQ, 'hq': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
'lq_center': lq_center, 'gt_center': gt_center,
'LQ_path': LQ_path, 'GT_path': full_path}
return d

View File

@@ -9,7 +9,7 @@ from io import BytesIO
# options.
class ImageCorruptor:
def __init__(self, opt):
-self.fixed_corruptions = opt['fixed_corruptions']
+self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
if self.num_corrupts == 0:
return
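Editor's note: the effect of this one-line change is that 'fixed_corruptions' becomes optional, so an ImageCorruptor can be built from an options dict that omits it. A minimal sketch (the import path is assumed; it is not shown in this diff):

from data.image_corruptor import ImageCorruptor  # assumed module path

# Before the change, opt['fixed_corruptions'] was read unconditionally and this
# raised KeyError; after it, this yields a corruptor that applies no corruptions.
corruptor = ImageCorruptor({'num_corrupts_per_image': 0})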

View File

@@ -29,7 +29,7 @@ class ImageFolderDataset:
self.weights = opt['weights']
# Just scan the given directory for images of standard types.
-supported_types = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG', 'gif', 'GIF']
+supported_types = ['jpg', 'jpeg', 'png', 'gif']
self.image_paths = []
for path, weight in zip(self.paths, self.weights):
cache_path = os.path.join(path, 'cache.pth')
@@ -95,7 +95,7 @@ class ImageFolderDataset:
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
-return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
+return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
if __name__ == '__main__':
@@ -118,7 +118,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
-k = 'LQ'
+k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:

View File

@@ -57,7 +57,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
-return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@@ -83,7 +83,7 @@ if __name__ == '__main__':
batch = None
for i in range(len(ds)):
import random
-k = 'LQ'
+k = 'lq'
element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)

View File

@@ -87,7 +87,7 @@ class MultiScaleDataset(data.Dataset):
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
patches_lq = torch.stack(patches_lq, dim=0)
-d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
+d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
return d
def __len__(self):

View File

@@ -42,7 +42,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
-return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@@ -68,7 +68,7 @@ if __name__ == '__main__':
batch = None
for i in range(len(ds)):
import random
-k = 'LQ'
+k = 'lq'
element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)

View File

@@ -36,7 +36,7 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
-return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
'LQ_path': path, 'GT_path': path}
@@ -62,7 +62,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
-k = 'LQ'
+k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:

View File

@@ -98,4 +98,4 @@ class Stylegan2Dataset(data.Dataset):
path = self.paths[index]
img = Image.open(path)
img = self.transform(img)
-return {'LQ': img, 'GT': img, 'GT_path': str(path)}
+return {'lq': img, 'hq': img, 'GT_path': str(path)}

View File

@@ -24,7 +24,7 @@ class TorchDataset(Dataset):
def __getitem__(self, item):
underlying_item = self.dataset[item][0]
-return {'LQ': underlying_item, 'GT': underlying_item,
+return {'lq': underlying_item, 'hq': underlying_item,
'LQ_path': str(item), 'GT_path': str(item)}
def __len__(self):

View File

@@ -4,7 +4,6 @@ import os
import torch
from torch.nn.parallel import DataParallel
import torch.nn as nn
-from apex.parallel import DistributedDataParallel
import models.lr_scheduler as lr_scheduler
import models.networks as networks
@@ -106,6 +105,8 @@ class ExtensibleTrainer(BaseModel):
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
for anet in all_networks:
if opt['dist']:
+# Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
+from apex.parallel import DistributedDataParallel
dnet = DistributedDataParallel(anet, delay_allreduce=True)
else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
@@ -160,18 +161,9 @@
o.zero_grad()
torch.cuda.empty_cache()
-self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
-if need_GT:
-self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
-input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
-self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
-else:
-self.hq = self.lq
-self.ref = self.lq
-self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
+self.dstate = {}
for k, v in data.items():
-if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
+if isinstance(v, torch.Tensor):
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
def optimize_parameters(self, step):
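Editor's note: the rewritten feed_data drops the hard-coded 'LQ'/'GT'/'ref' handling; every tensor in the incoming batch dict is now chunked along the batch dimension and staged on the device under its own key, so new dataset keys flow through without trainer changes. A small standalone illustration of what that loop produces (batch_factor, shapes and keys here are made up):

import torch

data = {'lq': torch.zeros(8, 3, 32, 32), 'hq': torch.zeros(8, 3, 128, 128), 'GT_path': ['a.png'] * 8}
batch_factor = 2   # illustrative; in the trainer this comes from the options
device = 'cpu'

dstate = {}
for k, v in data.items():
    if isinstance(v, torch.Tensor):
        dstate[k] = [t.to(device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]

# dstate['lq'] -> two tensors of shape [4, 3, 32, 32]; non-tensor entries
# such as 'GT_path' are simply skipped.
print([tuple(t.shape) for t in dstate['lq']])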
@@ -328,8 +320,8 @@
def get_current_visuals(self, need_GT=True):
# Conforms to an archaic format from MMSR.
-return {'LQ': self.eval_state['lq'][0].float().cpu(),
-'GT': self.eval_state['hq'][0].float().cpu(),
+return {'lq': self.eval_state['lq'][0].float().cpu(),
+'hq': self.eval_state['hq'][0].float().cpu(),
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
def print_network(self):

View File

@@ -209,7 +209,7 @@ class SpineNet(nn.Module):
def __init__(self,
arch,
in_channels=3,
-output_level=[3, 4, 5, 6, 7],
+output_level=[3, 4],
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
zero_init_residual=True,

View File

@@ -30,8 +30,8 @@ class FlowGaussianNll(evaluator.Evaluator):
print("Evaluating FlowGaussianNll..")
for batch in tqdm(self.dataloader):
dev = self.env['device']
-z, _, _ = self.model(gt=batch['GT'].to(dev),
-lr=batch['LQ'].to(dev),
+z, _, _ = self.model(gt=batch['hq'].to(dev),
+lr=batch['lq'].to(dev),
epses=[],
reverse=False,
add_gt_noise=False)

View File

@@ -39,7 +39,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
counter = 0
for batch in self.sampler:
noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
-batch_hq = [e['GT'] for e in batch]
+batch_hq = [e['hq'] for e in batch]
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
embedding = embedding_generator(resized_batch)

View File

@@ -66,9 +66,9 @@ class FeatureModel(BaseModel):
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
-self.var_L = data['LQ'].to(self.device) # LQ
+self.var_L = data['lq'].to(self.device) # LQ
if need_GT:
-self.real_H = data['GT'].to(self.device) # GT
+self.real_H = data['hq'].to(self.device) # GT
def optimize_parameters(self, step):
self.optimizer_G.zero_grad()

View File

@@ -88,7 +88,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
img_LQ = lq_template
ref = ref_template
-return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
+return {'lq': img_LQ, 'lq_fullsize_ref': ref,
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
def __len__(self):
@@ -159,8 +159,8 @@ if __name__ == "__main__":
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
if recurrent_mode and first_frame:
-b, c, h, w = data['LQ'].shape
-recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
+b, c, h, w = data['lq'].shape
+recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['lq'].device)
# Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of.
if 'recurrent_hr_generator' in opt.keys():
recurrent_gen = model.env['generators']['generator']

View File

@@ -0,0 +1,206 @@
import os
import shutil
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm
import numpy as np
from data.image_folder_dataset import ImageFolderDataset
from models.archs.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
def structural_euc_dist(x, y):
diff = torch.square(x - y)
sum = torch.sum(diff, dim=-1)
return torch.sqrt(sum)
def cosine_similarity(x, y):
x = norm(x)
y = norm(y)
return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
def norm(x):
sh = x.shape
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)
def im_norm(x):
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
def get_image_folder_dataloader(batch_size, num_workers):
dataset_opt = {
'name': 'amalgam',
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'weights': [1],
'target_size': 512,
'force_multiple': 32,
'scale': 1
}
dataset = ImageFolderDataset(dataset_opt)
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
def create_latent_database(model):
batch_size = 8
num_workers = 1
output_path = '../../results/byol_spinenet_latents/'
os.makedirs(output_path, exist_ok=True)
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
dict_count = 1
latent_dict = {}
all_paths = []
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
for b in range(latent.shape[0]):
im_path = batch['HQ_path'][b]
all_paths.append(im_path)
latent_dict[id] = latent[b].detach().cpu()
if (id+1) % 1000 == 0:
print("Saving checkpoint..")
torch.save(latent_dict, os.path.join(output_path, "latent_dict_%i.pth" % (dict_count,)))
latent_dict = {}
torch.save(all_paths, os.path.join(output_path, "all_paths.pth"))
dict_count += 1
id += 1
def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size):
_, c, h, w = latent.shape
lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name))
comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1)
cbl_shape = comparables.shape[:3]
assert cbl_shape[1] == 32
comparables = comparables.reshape(-1, c)
clat = latent.reshape(1,-1,h*w).permute(2,0,1)
cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size)
assert len(comparables) % batch_size == 0 # The reconstruction logic doesn't work if this is not the case.
mins = []
min_offsets = []
for cpbl_chunk in tqdm(cpbl_chunked):
cpbl_chunk = cpbl_chunk.to('cuda:1')
dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
_min = torch.min(dist, dim=-1)
mins.append(_min[0])
min_offsets.append(_min[1])
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
# There's some way to do this in torch, I just can't figure it out..
for i in range(len(mins[1])):
mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i]
return mins[0].cpu(), mins[1].cpu(), len(comparables)
def find_similar_latents(model):
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
hq_img_repo = '../../results/byol_spinenet_latents'
output_path = '../../results/byol_spinenet_similars'
batch_size = 1024
num_maps = 8
os.makedirs(output_path, exist_ok=True)
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0)
_, _, h, w = img_t.shape
img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
latent = model(img_t)[1]
_, c, h, w = latent.shape
mins, min_offsets = [], []
total_latents = -1
for d_id in range(1,num_maps+1):
mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size)
if total_latents != -1:
assert total_latents == tl
else:
total_latents = tl
mins.append(mn)
min_offsets.append(of)
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
# There's some way to do this in torch, I just can't figure it out..
for i in range(len(mins[1])):
mins[1][i] = mins[1][i] * total_latents + min_offsets[mins[1][i]][i]
min_ids = mins[1]
print("Constructing image map..")
doc_out = '''
<html><body><img id="imgmap" src="source.png" usemap="#map">
<map name="map">%s</map><br>
<button onclick="if(imgmap.src.includes('output.png')){imgmap.src='source.png';}else{imgmap.src='output.png';}">Swap Images</button>
</body></html>
'''
img_map_areas = []
img_out = torch.zeros((1,3,h*16,w*16))
for i, ind in enumerate(tqdm(min_ids)):
u = np.unravel_index(ind.item(), (num_maps*total_latents//(32*32),32,32))
h_, w_ = np.unravel_index(i, (h, w))
img = ToTensor()(Resize((512, 512))(Image.open(img_bank_paths[u[0]])))
t = 16 * u[1]
l = 16 * u[2]
patch = img[:, t:t+16, l:l+16]
img_out[:,:,h_*16:h_*16+16,w_*16:w_*16+16] = patch
# Also save the image with a masked map
mask = torch.full_like(img, fill_value=.3)
mask[:, t:t+16, l:l+16] = 1
masked_img = img * mask
masked_src_img_output_file = os.path.join(output_path, "%i_%i__%i.png" % (t, l, u[0]))
torchvision.utils.save_image(masked_img, masked_src_img_output_file)
# Update the image map areas.
img_map_areas.append('<area shape="rect" coords="%i,%i,%i,%i" href="%s">' % (w_*16,h_*16,w_*16+16,h_*16+16,masked_src_img_output_file))
torchvision.utils.save_image(img_out, os.path.join(output_path, "output.png"))
torchvision.utils.save_image(img_t, os.path.join(output_path, "source.png"))
doc_out = doc_out % ('\n'.join(img_map_areas))
with open(os.path.join(output_path, 'map.html'), 'w') as f:
print(doc_out, file=f)
def explore_latent_results(model):
batch_size = 16
num_workers = 1
output_path = '../../results/byol_spinenet_explore_latents/'
os.makedirs(output_path, exist_ok=True)
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
# This operation works by computing the distance of every structural index from the center and using that
# as a "heatmap".
b, c, h, w = latent.shape
center = latent[:, :, h//2, w//2].unsqueeze(-1).unsqueeze(-1)
centers = center.repeat(1, 1, h, w)
dist = cosine_similarity(latent, centers).unsqueeze(1)
dist = im_norm(dist)
torchvision.utils.save_image(dist, os.path.join(output_path, "%i.png" % id))
id += 1
if __name__ == '__main__':
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1')
model.load_state_dict(torch.load(pretrained_path), strict=True)
model.eval()
with torch.no_grad():
find_similar_latents(model)

View File

@@ -40,8 +40,8 @@ if __name__ == '__main__':
break
sampled += 1
-im = rgb2ycbcr(train_data['GT'].double())
-im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
+im = rgb2ycbcr(train_data['hq'].double())
+im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
size=im.shape[2:],
mode="bicubic", align_corners=False))
patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)

View File

@@ -16,7 +16,7 @@ import utils.util as util # noqa: E402
def main():
dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test
-mode = 'GT' # used for vimeo90k and REDS datasets
+mode = 'hq' # used for vimeo90k and REDS datasets
# vimeo90k: GT | LR | flow
# REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
# train_sharp_flowx4
@@ -159,7 +159,7 @@ def vimeo90k(mode):
read_all_imgs = False # whether real all images to memory with multiprocessing
# Set False for use limited memory
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
-if mode == 'GT':
+if mode == 'hq':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
@@ -204,7 +204,7 @@ def vimeo90k(mode):
keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
all_img_list = sorted(all_img_list)
keys = sorted(keys)
-if mode == 'GT': # only read the 4th frame for the GT mode
+if mode == 'hq': # only read the 4th frame for the GT mode
print('Only keep the 4th frame.')
all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
keys = [v for v in keys if v.endswith('_4')]
@@ -255,9 +255,9 @@ def vimeo90k(mode):
#### create meta information
meta_info = {}
-if mode == 'GT':
+if mode == 'hq':
meta_info['name'] = 'Vimeo90K_train_GT'
-elif mode == 'LR':
+elif mode == 'lq':
meta_info['name'] = 'Vimeo90K_train_LR'
elif mode == 'flow':
meta_info['name'] = 'Vimeo90K_train_flowx4'

View File

@@ -163,7 +163,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
srg_analyze = False
parser = argparse.ArgumentParser()
-parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
+parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@@ -180,10 +180,10 @@ if __name__ == "__main__":
gen = model.networks['generator']
gen.eval()
-mode = "temperature" # temperature | restore | latent_transfer | feed_through
+mode = "restore" # temperature | restore | latent_transfer | feed_through
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
-#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
-imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
+imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
+#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
scale = 2
resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
temperature = 1
@@ -224,6 +224,11 @@ if __name__ == "__main__":
t = image_2_tensor(img_file).to(model.env['device'])
if resample_factor != 1:
t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
+# Ensure the input image is a factor of 16.
+_, _, h, w = t.shape
+h = 16 * (h // 16)
+w = 16 * (w // 16)
+t = t[:, :, :h, :w]
resample_img = t
# Fetch the latent metrics & latents for each image we are resampling.
@@ -255,6 +260,6 @@ if __name__ == "__main__":
for j in range(len(lats)):
path = os.path.join(output_path, "%i_%i" % (im_it, j))
os.makedirs(path, exist_ok=True)
-torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it)))
+torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg" %(im_it)))
create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
path, [torch.zeros_like(l) for l in lats[j]], lats[j])

View File

@@ -85,8 +85,8 @@ def main():
if dataset == 'REDS' or dataset == 'Vimeo90K':
LQs = data['LQs']
else:
-LQ = data['LQ']
-GT = data['GT']
+LQ = data['lq']
+GT = data['hq']
if dataset == 'REDS' or dataset == 'Vimeo90K':
for j in range(LQs.size(1)):

View File

@@ -68,6 +68,6 @@ if __name__ == "__main__":
# removed += 1
imname = osp.basename(data['GT_path'][i])
if results[i]-dataset_mean > 1:
-torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@@ -66,7 +66,7 @@ if __name__ == "__main__":
model.test()
gen = model.eval_state['gen'][0].to(model.env['device'])
feagen = netF(gen)
-feareal = netF(data['GT'].to(model.env['device']))
+feareal = netF(data['hq'].to(model.env['device']))
losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
means.append(torch.mean(losses).item())
#print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
@@ -76,6 +76,6 @@ if __name__ == "__main__":
removed += 1
#imname = osp.basename(data['GT_path'][i])
#if losses[i] < 25000:
-# torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+# torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@@ -41,9 +41,9 @@ def forward_pass(model, output_dir, alteration_suffix=''):
save_img_path = osp.join(output_dir, img_name + '.png')
if need_GT:
-fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
+fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
psnr_sr = util.tensor2img(visuals[i])
-psnr_gt = util.tensor2img(data['GT'][i])
+psnr_gt = util.tensor2img(data['hq'][i])
psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)
util.save_img(sr_img, save_img_path)

View File

@@ -231,13 +231,13 @@ class Trainer:
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
-gt_img = util.tensor2img(visuals['GT'][b]) # uint8
+gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
-avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)

View File

@@ -3,6 +3,8 @@ import math
import argparse
import random
import logging
+import torchvision
from tqdm import tqdm
import torch
@@ -231,18 +233,18 @@
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
-gt_img = util.tensor2img(visuals['GT'][b]) # uint8
+gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
-avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
save_img_path = os.path.join(img_dir, img_base_name)
-util.save_img(sr_img, save_img_path)
+torchvision.utils.save_image(visuals['rlt'], save_img_path)
avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
@@ -291,7 +293,7 @@
if __name__ == '__main__':
parser = argparse.ArgumentParser()
-parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
+parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()