Directly use dataset keys

This has been a long time coming. Cleans up the messy "GT" nomenclature and simplifies ExtensibleTrainer.feed_data.
James Betker 2020-12-04 20:14:53 -07:00
parent 8a83b1c716
commit 11155aead4
23 changed files with 63 additions and 61 deletions
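
In short: every dataset now returns its tensors under the lowercase keys the trainer consumes ('lq', 'hq', ...), and ExtensibleTrainer.feed_data simply chunks whatever tensors the batch dict contains instead of special-casing 'LQ'/'GT'/'ref'. Below is a minimal sketch of the new convention, reconstructed from the diffs that follow; ExampleDataset and feed_data_sketch are illustrative names only, not code from the repository.

import torch
from torch.utils.data import DataLoader, Dataset

class ExampleDataset(Dataset):
    # Illustrative dataset: tensors are exposed directly under the keys the trainer uses.
    def __len__(self):
        return 8

    def __getitem__(self, item):
        hq = torch.rand(3, 128, 128)                       # high-quality target image
        lq = torch.rand(3, 32, 32)                         # low-quality input image
        return {'lq': lq, 'hq': hq, 'GT_path': str(item)}  # path entries stay strings

def feed_data_sketch(data, device='cpu', batch_factor=1):
    # Sketch of the simplified feed_data: every tensor in the batch dict is chunked
    # and moved to the device under its original key; nothing is renamed or special-cased.
    dstate = {}
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            dstate[k] = [t.to(device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
    return dstate

# Usage: the trainer state ends up keyed exactly as the dataset emitted it.
loader = DataLoader(ExampleDataset(), batch_size=4)
state = feed_data_sketch(next(iter(loader)))
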

View File

@@ -20,12 +20,15 @@ This directory contains several reference datasets which I have used in building
1. MultiframeDataset - Similar to SingleImageDataset, but infers a temporal relationship between images based on their
filenames: the last 12 characters before the file extension are assumed to be a frame counter. Images from this
dataset are grouped together with a temporal dimension for working with video data.
+1. ImageFolderDataset - Reads raw images from a folder and feeds them into the model. Capable of performing corruptions
+on those images like the above.
1. MultiscaleDataset - Reads full images from a directory and builds a tree of images constructed by cropping squares
from the source image and resizing them to the target size recursively until the native resolution is hit. Each
recursive step decreases the crop size by a factor of 2.
+1. TorchDataset - A wrapper for miscellaneous pytorch datasets (e.g. MNIST, CIFAR, etc) which extracts the images
+and reformats them in a way that the DLAS trainer understands.
1. FullImageDataset - An image patch dataset where the patches are dynamically extracted from full-size images. I have
-generally stopped using this for performance reasons in favor of SingleImageDataset but it is useful for validation
-and test so I keep it around.
+generally stopped using this for performance reasons and it should be considered deprecated.
## Information about the "chunked" format

View File

@@ -323,7 +323,7 @@ class FullImageDataset(data.Dataset):
gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
-d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
+d = {'lq': img_LQ, 'hq': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
'lq_center': lq_center, 'gt_center': gt_center,
'LQ_path': LQ_path, 'GT_path': full_path}
return d

View File

@@ -95,7 +95,7 @@ class ImageFolderDataset:
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
-return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
+return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
if __name__ == '__main__':
@@ -118,7 +118,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
-k = 'LQ'
+k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:

View File

@@ -57,7 +57,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
-return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@@ -83,7 +83,7 @@ if __name__ == '__main__':
batch = None
for i in range(len(ds)):
import random
-k = 'LQ'
+k = 'lq'
element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)

View File

@@ -87,7 +87,7 @@ class MultiScaleDataset(data.Dataset):
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
patches_lq = torch.stack(patches_lq, dim=0)
-d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
+d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
return d
def __len__(self):

View File

@@ -42,7 +42,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
-return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@@ -68,7 +68,7 @@ if __name__ == '__main__':
batch = None
for i in range(len(ds)):
import random
-k = 'LQ'
+k = 'lq'
element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)

View File

@@ -36,7 +36,7 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
-return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
+return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
'LQ_path': path, 'GT_path': path}
@@ -62,7 +62,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
-k = 'LQ'
+k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:

View File

@@ -98,4 +98,4 @@ class Stylegan2Dataset(data.Dataset):
path = self.paths[index]
img = Image.open(path)
img = self.transform(img)
-return {'LQ': img, 'GT': img, 'GT_path': str(path)}
+return {'lq': img, 'hq': img, 'GT_path': str(path)}

View File

@@ -24,7 +24,7 @@ class TorchDataset(Dataset):
def __getitem__(self, item):
underlying_item = self.dataset[item][0]
-return {'LQ': underlying_item, 'GT': underlying_item,
+return {'lq': underlying_item, 'hq': underlying_item,
'LQ_path': str(item), 'GT_path': str(item)}
def __len__(self):

View File

@@ -4,7 +4,6 @@ import os
import torch
from torch.nn.parallel import DataParallel
import torch.nn as nn
-from apex.parallel import DistributedDataParallel
import models.lr_scheduler as lr_scheduler
import models.networks as networks
@@ -106,6 +105,8 @@ class ExtensibleTrainer(BaseModel):
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
for anet in all_networks:
if opt['dist']:
+# Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
+from apex.parallel import DistributedDataParallel
dnet = DistributedDataParallel(anet, delay_allreduce=True)
else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
@@ -160,18 +161,9 @@
o.zero_grad()
torch.cuda.empty_cache()
-self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
-if need_GT:
-self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
-input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
-self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
-else:
-self.hq = self.lq
-self.ref = self.lq
-self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
+self.dstate = {}
for k, v in data.items():
-if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
+if isinstance(v, torch.Tensor):
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
def optimize_parameters(self, step):
@@ -328,8 +320,8 @@
def get_current_visuals(self, need_GT=True):
# Conforms to an archaic format from MMSR.
-return {'LQ': self.eval_state['lq'][0].float().cpu(),
-'GT': self.eval_state['hq'][0].float().cpu(),
+return {'lq': self.eval_state['lq'][0].float().cpu(),
+'hq': self.eval_state['hq'][0].float().cpu(),
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
def print_network(self):

View File

@@ -30,8 +30,8 @@ class FlowGaussianNll(evaluator.Evaluator):
print("Evaluating FlowGaussianNll..")
for batch in tqdm(self.dataloader):
dev = self.env['device']
-z, _, _ = self.model(gt=batch['GT'].to(dev),
-lr=batch['LQ'].to(dev),
+z, _, _ = self.model(gt=batch['hq'].to(dev),
+lr=batch['lq'].to(dev),
epses=[],
reverse=False,
add_gt_noise=False)

View File

@@ -39,7 +39,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
counter = 0
for batch in self.sampler:
noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
-batch_hq = [e['GT'] for e in batch]
+batch_hq = [e['hq'] for e in batch]
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
embedding = embedding_generator(resized_batch)

View File

@@ -66,9 +66,9 @@ class FeatureModel(BaseModel):
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
-self.var_L = data['LQ'].to(self.device) # LQ
+self.var_L = data['lq'].to(self.device) # LQ
if need_GT:
-self.real_H = data['GT'].to(self.device) # GT
+self.real_H = data['hq'].to(self.device) # GT
def optimize_parameters(self, step):
self.optimizer_G.zero_grad()

View File

@@ -88,7 +88,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
img_LQ = lq_template
ref = ref_template
-return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
+return {'lq': img_LQ, 'lq_fullsize_ref': ref,
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
def __len__(self):
@@ -159,8 +159,8 @@ if __name__ == "__main__":
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
if recurrent_mode and first_frame:
-b, c, h, w = data['LQ'].shape
-recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
+b, c, h, w = data['lq'].shape
+recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['lq'].device)
# Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of.
if 'recurrent_hr_generator' in opt.keys():
recurrent_gen = model.env['generators']['generator']

View File

@@ -40,8 +40,8 @@ if __name__ == '__main__':
break
sampled += 1
-im = rgb2ycbcr(train_data['GT'].double())
-im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
+im = rgb2ycbcr(train_data['hq'].double())
+im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
size=im.shape[2:],
mode="bicubic", align_corners=False))
patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)

View File

@@ -16,7 +16,7 @@ import utils.util as util # noqa: E402
def main():
dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test
-mode = 'GT' # used for vimeo90k and REDS datasets
+mode = 'hq' # used for vimeo90k and REDS datasets
# vimeo90k: GT | LR | flow
# REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
# train_sharp_flowx4
@@ -159,7 +159,7 @@ def vimeo90k(mode):
read_all_imgs = False # whether real all images to memory with multiprocessing
# Set False for use limited memory
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
-if mode == 'GT':
+if mode == 'hq':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
@@ -204,7 +204,7 @@ def vimeo90k(mode):
keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
all_img_list = sorted(all_img_list)
keys = sorted(keys)
-if mode == 'GT': # only read the 4th frame for the GT mode
+if mode == 'hq': # only read the 4th frame for the GT mode
print('Only keep the 4th frame.')
all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
keys = [v for v in keys if v.endswith('_4')]
@@ -255,9 +255,9 @@ def vimeo90k(mode):
#### create meta information
meta_info = {}
-if mode == 'GT':
+if mode == 'hq':
meta_info['name'] = 'Vimeo90K_train_GT'
-elif mode == 'LR':
+elif mode == 'lq':
meta_info['name'] = 'Vimeo90K_train_LR'
elif mode == 'flow':
meta_info['name'] = 'Vimeo90K_train_flowx4'

View File

@@ -163,7 +163,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
srg_analyze = False
parser = argparse.ArgumentParser()
-parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
+parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@@ -180,10 +180,10 @@ if __name__ == "__main__":
gen = model.networks['generator']
gen.eval()
-mode = "temperature" # temperature | restore | latent_transfer | feed_through
+mode = "restore" # temperature | restore | latent_transfer | feed_through
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
-#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
-imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
+imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
+#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
scale = 2
resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
temperature = 1
@@ -224,6 +224,11 @@ if __name__ == "__main__":
t = image_2_tensor(img_file).to(model.env['device'])
if resample_factor != 1:
t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
+# Ensure the input image is a factor of 16.
+_, _, h, w = t.shape
+h = 16 * (h // 16)
+w = 16 * (w // 16)
+t = t[:, :, :h, :w]
resample_img = t
# Fetch the latent metrics & latents for each image we are resampling.
@@ -255,6 +260,6 @@ if __name__ == "__main__":
for j in range(len(lats)):
path = os.path.join(output_path, "%i_%i" % (im_it, j))
os.makedirs(path, exist_ok=True)
-torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it)))
+torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg" %(im_it)))
create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
path, [torch.zeros_like(l) for l in lats[j]], lats[j])

View File

@@ -85,8 +85,8 @@ def main():
if dataset == 'REDS' or dataset == 'Vimeo90K':
LQs = data['LQs']
else:
-LQ = data['LQ']
-GT = data['GT']
+LQ = data['lq']
+GT = data['hq']
if dataset == 'REDS' or dataset == 'Vimeo90K':
for j in range(LQs.size(1)):

View File

@@ -68,6 +68,6 @@ if __name__ == "__main__":
# removed += 1
imname = osp.basename(data['GT_path'][i])
if results[i]-dataset_mean > 1:
-torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@@ -66,7 +66,7 @@ if __name__ == "__main__":
model.test()
gen = model.eval_state['gen'][0].to(model.env['device'])
feagen = netF(gen)
-feareal = netF(data['GT'].to(model.env['device']))
+feareal = netF(data['hq'].to(model.env['device']))
losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
means.append(torch.mean(losses).item())
#print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
@@ -76,6 +76,6 @@ if __name__ == "__main__":
removed += 1
#imname = osp.basename(data['GT_path'][i])
#if losses[i] < 25000:
-# torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
+# torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@@ -41,9 +41,9 @@ def forward_pass(model, output_dir, alteration_suffix=''):
save_img_path = osp.join(output_dir, img_name + '.png')
if need_GT:
-fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
+fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
psnr_sr = util.tensor2img(visuals[i])
-psnr_gt = util.tensor2img(data['GT'][i])
+psnr_gt = util.tensor2img(data['hq'][i])
psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)
util.save_img(sr_img, save_img_path)

View File

@@ -231,13 +231,13 @@ class Trainer:
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
-gt_img = util.tensor2img(visuals['GT'][b]) # uint8
+gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
-avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)

View File

@@ -3,6 +3,8 @@ import math
import argparse
import random
import logging
+import torchvision
from tqdm import tqdm
import torch
@@ -231,18 +233,18 @@ class Trainer:
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
-gt_img = util.tensor2img(visuals['GT'][b]) # uint8
+gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
-avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
save_img_path = os.path.join(img_dir, img_base_name)
-util.save_img(sr_img, save_img_path)
+torchvision.utils.save_image(visuals['rlt'], save_img_path)
avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
@@ -291,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
-parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
+parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()