Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-12-07 12:51:04 -07:00
commit bca59ed98a
27 changed files with 277 additions and 64 deletions

View File

@ -42,6 +42,11 @@ TBC..
## User Guide
TBC
### Development Environment
If you aren't already using [PyCharm](https://www.jetbrains.com/pycharm/), now is the time to try it out. This project was built in PyCharm and comes with
an IDEA project to get you started. I've done all of my development on this repo in this IDE and lean heavily
on its incredible debugger. It's free. Try it out. You won't be sorry.
### Dataset Preparation
DLAS comes with some Dataset instances that I have created for my own use. Unless you want to use one of the recipes above, you'll need to provide your own. Here is how to add your own Dataset:
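This diff truncates the rest of the section, but as a rough, hedged sketch of the contract a new Dataset has to satisfy (the class name, option keys, and loading logic below are hypothetical and not part of DLAS; the 'lq'/'hq' and '*_path' output keys follow the convention the rest of this commit standardizes on):

```python
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor


class MyImageDataset(Dataset):
    # Hypothetical example: yield dicts keyed 'lq'/'hq' (plus the path entries) so the
    # trainer can chunk every tensor in the dict and move it to the training device.
    def __init__(self, opt):
        self.paths = opt['paths']                                  # list of image files
        self.scale = opt['scale'] if 'scale' in opt.keys() else 1  # LQ downscale factor

    def __getitem__(self, index):
        hq = to_tensor(Image.open(self.paths[index]).convert('RGB'))
        lq = torch.nn.functional.interpolate(hq.unsqueeze(0),
                                             scale_factor=1 / self.scale,
                                             mode='area').squeeze(0)
        return {'lq': lq, 'hq': hq,
                'LQ_path': self.paths[index], 'GT_path': self.paths[index]}

    def __len__(self):
        return len(self.paths)
```

You will presumably also need to wire the new class into wherever DLAS constructs datasets from the training options; that step is not shown here.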

View File

@ -20,12 +20,15 @@ This directory contains several reference datasets which I have used in building
1. MultiframeDataset - Similar to SingleImageDataset, but infers a temporal relationship between images based on their
filenames: the last 12 characters before the file extension are assumed to be a frame counter. Images from this
dataset are grouped together with a temporal dimension for working with video data (see the grouping sketch after this list).
1. ImageFolderDataset - Reads raw images from a folder and feeds them into the model. Capable of performing corruptions
on those images like the above.
1. MultiscaleDataset - Reads full images from a directory and builds a tree of images constructed by cropping squares
from the source image and resizing them to the target size recursively until the native resolution is hit. Each
recursive step decreases the crop size by a factor of 2.
1. TorchDataset - A wrapper for miscellaneous PyTorch datasets (e.g. MNIST, CIFAR) which extracts the images
and reformats them in a way that the DLAS trainer understands.
1. FullImageDataset - An image patch dataset where the patches are dynamically extracted from full-size images. I have
generally stopped using this for performance reasons in favor of SingleImageDataset but it is useful for validation
and test so I keep it around.
generally stopped using this for performance reasons and it should be considered deprecated.
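As a concrete illustration of the MultiframeDataset filename convention described above, here is a hedged sketch (not the actual implementation) of grouping frames when the last 12 characters of each filename stem are treated as a frame counter:

```python
import os
from collections import defaultdict

def group_frames(image_paths):
    # Illustrative only: split each stem into (sequence id, 12-char frame counter) and
    # collect the frames of each sequence in counter order.
    sequences = defaultdict(list)
    for p in image_paths:
        stem, _ = os.path.splitext(os.path.basename(p))
        sequence_id, frame_counter = stem[:-12], stem[-12:]
        sequences[sequence_id].append((frame_counter, p))
    return {seq: [p for _, p in sorted(frames)] for seq, frames in sequences.items()}
```

Assuming zero-padded counters, sorting the counter strings lexically gives the temporal order the dataset stacks along its frame dimension.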
## Information about the "chunked" format

View File

@ -323,7 +323,7 @@ class FullImageDataset(data.Dataset):
gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
d = {'lq': img_LQ, 'hq': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
'lq_center': lq_center, 'gt_center': gt_center,
'LQ_path': LQ_path, 'GT_path': full_path}
return d

View File

@ -9,7 +9,7 @@ from io import BytesIO
# options.
class ImageCorruptor:
def __init__(self, opt):
self.fixed_corruptions = opt['fixed_corruptions']
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
if self.num_corrupts == 0:
return

View File

@ -29,7 +29,7 @@ class ImageFolderDataset:
self.weights = opt['weights']
# Just scan the given directory for images of standard types.
supported_types = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG', 'gif', 'GIF']
supported_types = ['jpg', 'jpeg', 'png', 'gif']
self.image_paths = []
for path, weight in zip(self.paths, self.weights):
cache_path = os.path.join(path, 'cache.pth')
@ -95,7 +95,7 @@ class ImageFolderDataset:
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'HQ_path': self.image_paths[item]}
if __name__ == '__main__':
@ -118,7 +118,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
k = 'LQ'
k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:

View File

@ -57,7 +57,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@ -83,7 +83,7 @@ if __name__ == '__main__':
batch = None
for i in range(len(ds)):
import random
k = 'LQ'
k = 'lq'
element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)

View File

@ -87,7 +87,7 @@ class MultiScaleDataset(data.Dataset):
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
patches_lq = torch.stack(patches_lq, dim=0)
d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
return d
def __len__(self):

View File

@ -42,7 +42,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@ -68,7 +68,7 @@ if __name__ == '__main__':
batch = None
for i in range(len(ds)):
import random
k = 'LQ'
k = 'lq'
element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)

View File

@ -36,7 +36,7 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
'LQ_path': path, 'GT_path': path}
@ -62,7 +62,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
k = 'LQ'
k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:

View File

@ -98,4 +98,4 @@ class Stylegan2Dataset(data.Dataset):
path = self.paths[index]
img = Image.open(path)
img = self.transform(img)
return {'LQ': img, 'GT': img, 'GT_path': str(path)}
return {'lq': img, 'hq': img, 'GT_path': str(path)}

View File

@ -24,7 +24,7 @@ class TorchDataset(Dataset):
def __getitem__(self, item):
underlying_item = self.dataset[item][0]
return {'LQ': underlying_item, 'GT': underlying_item,
return {'lq': underlying_item, 'hq': underlying_item,
'LQ_path': str(item), 'GT_path': str(item)}
def __len__(self):

View File

@ -4,7 +4,6 @@ import os
import torch
from torch.nn.parallel import DataParallel
import torch.nn as nn
from apex.parallel import DistributedDataParallel
import models.lr_scheduler as lr_scheduler
import models.networks as networks
@ -106,6 +105,8 @@ class ExtensibleTrainer(BaseModel):
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
for anet in all_networks:
if opt['dist']:
# Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
from apex.parallel import DistributedDataParallel
dnet = DistributedDataParallel(anet, delay_allreduce=True)
else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
@ -160,18 +161,9 @@ class ExtensibleTrainer(BaseModel):
o.zero_grad()
torch.cuda.empty_cache()
self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
if need_GT:
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
else:
self.hq = self.lq
self.ref = self.lq
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
self.dstate = {}
for k, v in data.items():
if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
if isinstance(v, torch.Tensor):
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
def optimize_parameters(self, step):
@ -328,8 +320,8 @@ class ExtensibleTrainer(BaseModel):
def get_current_visuals(self, need_GT=True):
# Conforms to an archaic format from MMSR.
return {'LQ': self.eval_state['lq'][0].float().cpu(),
'GT': self.eval_state['hq'][0].float().cpu(),
return {'lq': self.eval_state['lq'][0].float().cpu(),
'hq': self.eval_state['hq'][0].float().cpu(),
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
def print_network(self):

View File

@ -209,7 +209,7 @@ class SpineNet(nn.Module):
def __init__(self,
arch,
in_channels=3,
output_level=[3, 4, 5, 6, 7],
output_level=[3, 4],
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
zero_init_residual=True,

View File

@ -30,8 +30,8 @@ class FlowGaussianNll(evaluator.Evaluator):
print("Evaluating FlowGaussianNll..")
for batch in tqdm(self.dataloader):
dev = self.env['device']
z, _, _ = self.model(gt=batch['GT'].to(dev),
lr=batch['LQ'].to(dev),
z, _, _ = self.model(gt=batch['hq'].to(dev),
lr=batch['lq'].to(dev),
epses=[],
reverse=False,
add_gt_noise=False)

View File

@ -39,7 +39,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
counter = 0
for batch in self.sampler:
noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
batch_hq = [e['GT'] for e in batch]
batch_hq = [e['hq'] for e in batch]
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
embedding = embedding_generator(resized_batch)

View File

@ -66,9 +66,9 @@ class FeatureModel(BaseModel):
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
self.var_L = data['LQ'].to(self.device) # LQ
self.var_L = data['lq'].to(self.device) # LQ
if need_GT:
self.real_H = data['GT'].to(self.device) # GT
self.real_H = data['hq'].to(self.device) # GT
def optimize_parameters(self, step):
self.optimizer_G.zero_grad()

View File

@ -88,7 +88,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
img_LQ = lq_template
ref = ref_template
return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
return {'lq': img_LQ, 'lq_fullsize_ref': ref,
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
def __len__(self):
@ -159,8 +159,8 @@ if __name__ == "__main__":
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
if recurrent_mode and first_frame:
b, c, h, w = data['LQ'].shape
recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
b, c, h, w = data['lq'].shape
recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['lq'].device)
# Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of.
if 'recurrent_hr_generator' in opt.keys():
recurrent_gen = model.env['generators']['generator']

View File

@ -0,0 +1,206 @@
import os
import shutil
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm
import numpy as np
from data.image_folder_dataset import ImageFolderDataset
from models.archs.spinenet_arch import SpineNet
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
def structural_euc_dist(x, y):
diff = torch.square(x - y)
sum = torch.sum(diff, dim=-1)
return torch.sqrt(sum)
def cosine_similarity(x, y):
x = norm(x)
y = norm(y)
return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
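# Standardizes x along its last (channel) dimension; the reshape keeps the mean/std broadcastable against x.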
def norm(x):
sh = x.shape
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)
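# Per-image standardization over the spatial dims, then shifted/scaled so the result is centered on 0.5 for saving as an image.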
def im_norm(x):
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
def get_image_folder_dataloader(batch_size, num_workers):
dataset_opt = {
'name': 'amalgam',
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'weights': [1],
'target_size': 512,
'force_multiple': 32,
'scale': 1
}
dataset = ImageFolderDataset(dataset_opt)
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
def create_latent_database(model):
batch_size = 8
num_workers = 1
output_path = '../../results/byol_spinenet_latents/'
os.makedirs(output_path, exist_ok=True)
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
dict_count = 1
latent_dict = {}
all_paths = []
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
for b in range(latent.shape[0]):
im_path = batch['HQ_path'][b]
all_paths.append(im_path)
latent_dict[id] = latent[b].detach().cpu()
if (id+1) % 1000 == 0:
print("Saving checkpoint..")
torch.save(latent_dict, os.path.join(output_path, "latent_dict_%i.pth" % (dict_count,)))
latent_dict = {}
torch.save(all_paths, os.path.join(output_path, "all_paths.pth"))
dict_count += 1
id += 1
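# For every spatial position of `latent`, find the closest latent vector stored in one on-disk dictionary file.
# Returns the per-position minimum distances, their flat indices within that file, and the file's latent count.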
def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size):
_, c, h, w = latent.shape
lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name))
comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1)
cbl_shape = comparables.shape[:3]
assert cbl_shape[1] == 32
comparables = comparables.reshape(-1, c)
clat = latent.reshape(1,-1,h*w).permute(2,0,1)
cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size)
assert len(comparables) % batch_size == 0 # The reconstruction logic doesn't work if this is not the case.
mins = []
min_offsets = []
for cpbl_chunk in tqdm(cpbl_chunked):
cpbl_chunk = cpbl_chunk.to('cuda:1')
dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
_min = torch.min(dist, dim=-1)
mins.append(_min[0])
min_offsets.append(_min[1])
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
# There's some way to do this in torch, I just can't figure it out..
for i in range(len(mins[1])):
mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i]
return mins[0].cpu(), mins[1].cpu(), len(comparables)
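# Reconstructs a query image as a patchwork of its nearest stored latents and writes an HTML image map
# that links each patch back to the source image it came from.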
def find_similar_latents(model):
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
hq_img_repo = '../../results/byol_spinenet_latents'
output_path = '../../results/byol_spinenet_similars'
batch_size = 1024
num_maps = 8
os.makedirs(output_path, exist_ok=True)
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0)
_, _, h, w = img_t.shape
img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
latent = model(img_t)[1]
_, c, h, w = latent.shape
mins, min_offsets = [], []
total_latents = -1
for d_id in range(1,num_maps+1):
mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size)
if total_latents != -1:
assert total_latents == tl
else:
total_latents = tl
mins.append(mn)
min_offsets.append(of)
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
# There's some way to do this in torch, I just can't figure it out..
for i in range(len(mins[1])):
mins[1][i] = mins[1][i] * total_latents + min_offsets[mins[1][i]][i]
min_ids = mins[1]
print("Constructing image map..")
doc_out = '''
<html><body><img id="imgmap" src="source.png" usemap="#map">
<map name="map">%s</map><br>
<button onclick="if(imgmap.src.includes('output.png')){imgmap.src='source.png';}else{imgmap.src='output.png';}">Swap Images</button>
</body></html>
'''
img_map_areas = []
img_out = torch.zeros((1,3,h*16,w*16))
for i, ind in enumerate(tqdm(min_ids)):
u = np.unravel_index(ind.item(), (num_maps*total_latents//(32*32),32,32))
h_, w_ = np.unravel_index(i, (h, w))
img = ToTensor()(Resize((512, 512))(Image.open(img_bank_paths[u[0]])))
t = 16 * u[1]
l = 16 * u[2]
patch = img[:, t:t+16, l:l+16]
img_out[:,:,h_*16:h_*16+16,w_*16:w_*16+16] = patch
# Also save the image with a masked map
mask = torch.full_like(img, fill_value=.3)
mask[:, t:t+16, l:l+16] = 1
masked_img = img * mask
masked_src_img_output_file = os.path.join(output_path, "%i_%i__%i.png" % (t, l, u[0]))
torchvision.utils.save_image(masked_img, masked_src_img_output_file)
# Update the image map areas.
img_map_areas.append('<area shape="rect" coords="%i,%i,%i,%i" href="%s">' % (w_*16,h_*16,w_*16+16,h_*16+16,masked_src_img_output_file))
torchvision.utils.save_image(img_out, os.path.join(output_path, "output.png"))
torchvision.utils.save_image(img_t, os.path.join(output_path, "source.png"))
doc_out = doc_out % ('\n'.join(img_map_areas))
with open(os.path.join(output_path, 'map.html'), 'w') as f:
print(doc_out, file=f)
def explore_latent_results(model):
batch_size = 16
num_workers = 1
output_path = '../../results/byol_spinenet_explore_latents/'
os.makedirs(output_path, exist_ok=True)
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
# This operation works by computing the distance of every structural index from the center and using that
# as a "heatmap".
b, c, h, w = latent.shape
center = latent[:, :, h//2, w//2].unsqueeze(-1).unsqueeze(-1)
centers = center.repeat(1, 1, h, w)
dist = cosine_similarity(latent, centers).unsqueeze(1)
dist = im_norm(dist)
torchvision.utils.save_image(dist, os.path.join(output_path, "%i.png" % id))
id += 1
if __name__ == '__main__':
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1')
model.load_state_dict(torch.load(pretrained_path), strict=True)
model.eval()
with torch.no_grad():
find_similar_latents(model)

View File

@ -40,8 +40,8 @@ if __name__ == '__main__':
break
sampled += 1
im = rgb2ycbcr(train_data['GT'].double())
im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
im = rgb2ycbcr(train_data['hq'].double())
im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
size=im.shape[2:],
mode="bicubic", align_corners=False))
patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)

View File

@ -16,7 +16,7 @@ import utils.util as util # noqa: E402
def main():
dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test
mode = 'GT' # used for vimeo90k and REDS datasets
mode = 'hq' # used for vimeo90k and REDS datasets
# vimeo90k: GT | LR | flow
# REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
# train_sharp_flowx4
@ -159,7 +159,7 @@ def vimeo90k(mode):
read_all_imgs = False # whether to read all images into memory with multiprocessing
# Set False to limit memory usage
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
if mode == 'GT':
if mode == 'hq':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
@ -204,7 +204,7 @@ def vimeo90k(mode):
keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
all_img_list = sorted(all_img_list)
keys = sorted(keys)
if mode == 'GT': # only read the 4th frame for the GT mode
if mode == 'hq': # only read the 4th frame for the GT mode
print('Only keep the 4th frame.')
all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
keys = [v for v in keys if v.endswith('_4')]
@ -255,9 +255,9 @@ def vimeo90k(mode):
#### create meta information
meta_info = {}
if mode == 'GT':
if mode == 'hq':
meta_info['name'] = 'Vimeo90K_train_GT'
elif mode == 'LR':
elif mode == 'lq':
meta_info['name'] = 'Vimeo90K_train_LR'
elif mode == 'flow':
meta_info['name'] = 'Vimeo90K_train_flowx4'

View File

@ -163,7 +163,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
srg_analyze = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@ -180,10 +180,10 @@ if __name__ == "__main__":
gen = model.networks['generator']
gen.eval()
mode = "temperature" # temperature | restore | latent_transfer | feed_through
mode = "restore" # temperature | restore | latent_transfer | feed_through
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
scale = 2
resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
temperature = 1
@ -224,6 +224,11 @@ if __name__ == "__main__":
t = image_2_tensor(img_file).to(model.env['device'])
if resample_factor != 1:
t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
# Ensure the input image is a factor of 16.
_, _, h, w = t.shape
h = 16 * (h // 16)
w = 16 * (w // 16)
t = t[:, :, :h, :w]
resample_img = t
# Fetch the latent metrics & latents for each image we are resampling.
@ -255,6 +260,6 @@ if __name__ == "__main__":
for j in range(len(lats)):
path = os.path.join(output_path, "%i_%i" % (im_it, j))
os.makedirs(path, exist_ok=True)
torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it)))
torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg" %(im_it)))
create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
path, [torch.zeros_like(l) for l in lats[j]], lats[j])

View File

@ -85,8 +85,8 @@ def main():
if dataset == 'REDS' or dataset == 'Vimeo90K':
LQs = data['LQs']
else:
LQ = data['LQ']
GT = data['GT']
LQ = data['lq']
GT = data['hq']
if dataset == 'REDS' or dataset == 'Vimeo90K':
for j in range(LQs.size(1)):

View File

@ -68,6 +68,6 @@ if __name__ == "__main__":
# removed += 1
imname = osp.basename(data['GT_path'][i])
if results[i]-dataset_mean > 1:
torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@ -66,7 +66,7 @@ if __name__ == "__main__":
model.test()
gen = model.eval_state['gen'][0].to(model.env['device'])
feagen = netF(gen)
feareal = netF(data['GT'].to(model.env['device']))
feareal = netF(data['hq'].to(model.env['device']))
losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
means.append(torch.mean(losses).item())
#print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
@ -76,6 +76,6 @@ if __name__ == "__main__":
removed += 1
#imname = osp.basename(data['GT_path'][i])
#if losses[i] < 25000:
# torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
# torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@ -41,9 +41,9 @@ def forward_pass(model, output_dir, alteration_suffix=''):
save_img_path = osp.join(output_dir, img_name + '.png')
if need_GT:
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
psnr_sr = util.tensor2img(visuals[i])
psnr_gt = util.tensor2img(data['GT'][i])
psnr_gt = util.tensor2img(data['hq'][i])
psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)
util.save_img(sr_img, save_img_path)

View File

@ -231,13 +231,13 @@ class Trainer:
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)

View File

@ -3,6 +3,8 @@ import math
import argparse
import random
import logging
import torchvision
from tqdm import tqdm
import torch
@ -231,18 +233,18 @@ class Trainer:
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
save_img_path = os.path.join(img_dir, img_base_name)
util.save_img(sr_img, save_img_path)
torchvision.utils.save_image(visuals['rlt'], save_img_path)
avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
@ -291,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()