Directly use dataset keys

This has been a long time coming. Cleans up the messy "GT" nomenclature and simplifies ExtensibleTrainer.feed_data.
This commit is contained in:
James Betker 2020-12-04 20:14:53 -07:00
parent 8a83b1c716
commit 11155aead4
23 changed files with 63 additions and 61 deletions
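For orientation, a minimal sketch of what the new convention amounts to on the trainer side. The chunking mirrors the ExtensibleTrainer.feed_data hunk further down; the standalone helper and the example batch are illustrative only, not part of the commit.

```python
import torch

def feed_data_sketch(data: dict, device: torch.device, batch_factor: int = 1) -> dict:
    # Datasets now emit lowercase keys ('lq', 'hq', ...), so the trainer no longer
    # special-cases 'LQ'/'GT'/'ref': every tensor in the batch dict is split into
    # micro-batches and moved to the device uniformly.
    dstate = {}
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            dstate[k] = [t.to(device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
    return dstate

# Example batch in the post-commit format (shapes are arbitrary):
batch = {'lq': torch.rand(4, 3, 32, 32), 'hq': torch.rand(4, 3, 128, 128), 'GT_path': ['a.png'] * 4}
state = feed_data_sketch(batch, torch.device('cpu'), batch_factor=2)  # two device chunks per tensor key
```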

View File

@ -20,12 +20,15 @@ This directory contains several reference datasets which I have used in building
1. MultiframeDataset - Similar to SingleImageDataset, but infers a temporal relationship between images based on their
filenames: the last 12 characters before the file extension are assumed to be a frame counter. Images from this
dataset are grouped together with a temporal dimension for working with video data.
1. ImageFolderDataset - Reads raw images from a folder and feeds them into the model. Capable of performing corruptions
on those images like the above.
1. MultiscaleDataset - Reads full images from a directory and builds a tree of images constructed by cropping squares
from the source image and resizing them to the target size recursively until the native resolution is hit. Each
recursive step decreases the crop size by a factor of 2.
1. TorchDataset - A wrapper for miscellaneous pytorch datasets (e.g. MNIST, CIFAR, etc) which extracts the images
and reformats them in a way that the DLAS trainer understands.
1. FullImageDataset - An image patch dataset where the patches are dynamically extracted from full-size images. I have
generally stopped using this for performance reasons in favor of SingleImageDataset but it is useful for validation
and test so I keep it around.
generally stopped using this for performance reasons and it should be considered deprecated.
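For reference, each of the datasets listed above yields a plain dict keyed with the lowercase names introduced by this commit. A minimal sketch, assuming an arbitrary (image, label) dataset is being wrapped; the class is illustrative and only mirrors the TorchDataset diff below.

```python
import torch
from torch.utils.data import Dataset

class DictFormatSketch(Dataset):
    """Illustrative wrapper: re-emits items from any (image, label) dataset in the
    dict format the DLAS trainer consumes ('lq'/'hq' tensors plus path-style keys)."""
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __getitem__(self, item):
        img = self.wrapped[item][0]    # the underlying image tensor
        return {'lq': img, 'hq': img,  # lowercase keys replace the old 'LQ'/'GT'
                'LQ_path': str(item), 'GT_path': str(item)}

    def __len__(self):
        return len(self.wrapped)
```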
## Information about the "chunked" format

View File

@ -323,7 +323,7 @@ class FullImageDataset(data.Dataset):
gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
d = {'lq': img_LQ, 'hq': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
'lq_center': lq_center, 'gt_center': gt_center,
'LQ_path': LQ_path, 'GT_path': full_path}
return d

View File

@ -95,7 +95,7 @@ class ImageFolderDataset:
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
return {'LQ': lq, 'GT': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
return {'lq': lq, 'hq': hq, 'LQ_path': self.image_paths[item], 'GT_path': self.image_paths[item]}
if __name__ == '__main__':
@ -118,7 +118,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
k = 'LQ'
k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:

View File

@ -57,7 +57,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@ -83,7 +83,7 @@ if __name__ == '__main__':
batch = None
for i in range(len(ds)):
import random
k = 'LQ'
k = 'lq'
element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)

View File

@ -87,7 +87,7 @@ class MultiScaleDataset(data.Dataset):
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
patches_lq = torch.stack(patches_lq, dim=0)
d = {'LQ': patches_lq, 'GT': patches_hq, 'GT_path': full_path}
d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
return d
def __len__(self):

View File

@ -42,7 +42,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
return {'GT_path': path, 'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
@ -68,7 +68,7 @@ if __name__ == '__main__':
batch = None
for i in range(len(ds)):
import random
k = 'LQ'
k = 'lq'
element = ds[random.randint(0,len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)

View File

@ -36,7 +36,7 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
'LQ_path': path, 'GT_path': path}
@ -62,7 +62,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
k = 'LQ'
k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:

View File

@ -98,4 +98,4 @@ class Stylegan2Dataset(data.Dataset):
path = self.paths[index]
img = Image.open(path)
img = self.transform(img)
return {'LQ': img, 'GT': img, 'GT_path': str(path)}
return {'lq': img, 'hq': img, 'GT_path': str(path)}

View File

@ -24,7 +24,7 @@ class TorchDataset(Dataset):
def __getitem__(self, item):
underlying_item = self.dataset[item][0]
return {'LQ': underlying_item, 'GT': underlying_item,
return {'lq': underlying_item, 'hq': underlying_item,
'LQ_path': str(item), 'GT_path': str(item)}
def __len__(self):

View File

@ -4,7 +4,6 @@ import os
import torch
from torch.nn.parallel import DataParallel
import torch.nn as nn
from apex.parallel import DistributedDataParallel
import models.lr_scheduler as lr_scheduler
import models.networks as networks
@ -106,6 +105,8 @@ class ExtensibleTrainer(BaseModel):
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
for anet in all_networks:
if opt['dist']:
# Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
from apex.parallel import DistributedDataParallel
dnet = DistributedDataParallel(anet, delay_allreduce=True)
else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
@ -160,18 +161,9 @@ class ExtensibleTrainer(BaseModel):
o.zero_grad()
torch.cuda.empty_cache()
self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.batch_factor, dim=0)]
if need_GT:
self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
self.ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.batch_factor, dim=0)]
else:
self.hq = self.lq
self.ref = self.lq
self.dstate = {'lq': self.lq, 'hq': self.hq, 'ref': self.ref}
self.dstate = {}
for k, v in data.items():
if k not in ['LQ', 'ref', 'GT'] and isinstance(v, torch.Tensor):
if isinstance(v, torch.Tensor):
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=self.batch_factor, dim=0)]
def optimize_parameters(self, step):
@ -328,8 +320,8 @@ class ExtensibleTrainer(BaseModel):
def get_current_visuals(self, need_GT=True):
# Conforms to an archaic format from MMSR.
return {'LQ': self.eval_state['lq'][0].float().cpu(),
'GT': self.eval_state['hq'][0].float().cpu(),
return {'lq': self.eval_state['lq'][0].float().cpu(),
'hq': self.eval_state['hq'][0].float().cpu(),
'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
def print_network(self):

View File

@ -30,8 +30,8 @@ class FlowGaussianNll(evaluator.Evaluator):
print("Evaluating FlowGaussianNll..")
for batch in tqdm(self.dataloader):
dev = self.env['device']
z, _, _ = self.model(gt=batch['GT'].to(dev),
lr=batch['LQ'].to(dev),
z, _, _ = self.model(gt=batch['hq'].to(dev),
lr=batch['lq'].to(dev),
epses=[],
reverse=False,
add_gt_noise=False)

View File

@ -39,7 +39,7 @@ class SrStyleTransferEvaluator(evaluator.Evaluator):
counter = 0
for batch in self.sampler:
noise = torch.FloatTensor(self.batch_sz, 3, self.im_sz, self.im_sz).uniform_(0., 1.).to(self.env['device'])
batch_hq = [e['GT'] for e in batch]
batch_hq = [e['hq'] for e in batch]
batch_hq = torch.stack(batch_hq, dim=0).to(self.env['device'])
resized_batch = torch.nn.functional.interpolate(batch_hq, scale_factor=1/self.scale, mode="area")
embedding = embedding_generator(resized_batch)

View File

@ -66,9 +66,9 @@ class FeatureModel(BaseModel):
self.log_dict = OrderedDict()
def feed_data(self, data, need_GT=True):
self.var_L = data['LQ'].to(self.device) # LQ
self.var_L = data['lq'].to(self.device) # LQ
if need_GT:
self.real_H = data['GT'].to(self.device) # GT
self.real_H = data['hq'].to(self.device) # GT
def optimize_parameters(self, step):
self.optimizer_G.zero_grad()

View File

@ -88,7 +88,7 @@ class FfmpegBackedVideoDataset(data.Dataset):
img_LQ = lq_template
ref = ref_template
return {'LQ': img_LQ, 'lq_fullsize_ref': ref,
return {'lq': img_LQ, 'lq_fullsize_ref': ref,
'lq_center': torch.tensor([img_LQ.shape[1] // 2, img_LQ.shape[2] // 2], dtype=torch.long) }
def __len__(self):
@ -159,8 +159,8 @@ if __name__ == "__main__":
need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
if recurrent_mode and first_frame:
b, c, h, w = data['LQ'].shape
recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['LQ'].device)
b, c, h, w = data['lq'].shape
recurrent_entry = torch.zeros((b,c,h*scale,w*scale), device=data['lq'].device)
# Optionally swap out the 'generator' for the first frame to create a better image that the recurrent generator works off of.
if 'recurrent_hr_generator' in opt.keys():
recurrent_gen = model.env['generators']['generator']

View File

@ -40,8 +40,8 @@ if __name__ == '__main__':
break
sampled += 1
im = rgb2ycbcr(train_data['GT'].double())
im_LR = rgb2ycbcr(F.interpolate(train_data['LQ'].double(),
im = rgb2ycbcr(train_data['hq'].double())
im_LR = rgb2ycbcr(F.interpolate(train_data['lq'].double(),
size=im.shape[2:],
mode="bicubic", align_corners=False))
patches_hr = extract_patches_2d(img=im, patch_shape=(patch_size,patch_size), batch_first=True)

View File

@ -16,7 +16,7 @@ import utils.util as util # noqa: E402
def main():
dataset = 'DIV2K_demo' # vimeo90K | REDS | general (e.g., DIV2K, 291) | DIV2K_demo |test
mode = 'GT' # used for vimeo90k and REDS datasets
mode = 'hq' # used for vimeo90k and REDS datasets
# vimeo90k: GT | LR | flow
# REDS: train_sharp, train_sharp_bicubic, train_blur_bicubic, train_blur, train_blur_comp
# train_sharp_flowx4
@ -159,7 +159,7 @@ def vimeo90k(mode):
read_all_imgs = False # whether real all images to memory with multiprocessing
# Set False for use limited memory
BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False
if mode == 'GT':
if mode == 'hq':
img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
@ -204,7 +204,7 @@ def vimeo90k(mode):
keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
all_img_list = sorted(all_img_list)
keys = sorted(keys)
if mode == 'GT': # only read the 4th frame for the GT mode
if mode == 'hq': # only read the 4th frame for the GT mode
print('Only keep the 4th frame.')
all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
keys = [v for v in keys if v.endswith('_4')]
@ -255,9 +255,9 @@ def vimeo90k(mode):
#### create meta information
meta_info = {}
if mode == 'GT':
if mode == 'hq':
meta_info['name'] = 'Vimeo90K_train_GT'
elif mode == 'LR':
elif mode == 'lq':
meta_info['name'] = 'Vimeo90K_train_LR'
elif mode == 'flow':
meta_info['name'] = 'Vimeo90K_train_flowx4'

View File

@ -163,7 +163,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
srg_analyze = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../experiments/train_exd_imgset_srflow/train_exd_imgset_srflow.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/train_exd_imgsetext_srflow_frompsnr.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@ -180,10 +180,10 @@ if __name__ == "__main__":
gen = model.networks['generator']
gen.eval()
mode = "temperature" # temperature | restore | latent_transfer | feed_through
mode = "restore" # temperature | restore | latent_transfer | feed_through
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\val2\\lr\\*"
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\*"
#imgs_to_resample_pattern = "F:\\4k6k\\datasets\\ns_images\\imagesets\\images-half\\*lanette*"
scale = 2
resample_factor = 1 # When != 1, the HR image is upsampled by this factor using a bicubic to get the local latents.
temperature = 1
@ -224,6 +224,11 @@ if __name__ == "__main__":
t = image_2_tensor(img_file).to(model.env['device'])
if resample_factor != 1:
t = F.interpolate(t, scale_factor=resample_factor, mode="bicubic")
# Ensure the input image is a factor of 16.
_, _, h, w = t.shape
h = 16 * (h // 16)
w = 16 * (w // 16)
t = t[:, :, :h, :w]
resample_img = t
# Fetch the latent metrics & latents for each image we are resampling.
@ -255,6 +260,6 @@ if __name__ == "__main__":
for j in range(len(lats)):
path = os.path.join(output_path, "%i_%i" % (im_it, j))
os.makedirs(path, exist_ok=True)
torchvision.utils.save_image(resample_img, os.path.join(path, "%i_orig.jpg" %(im_it)))
torchvision.utils.save_image(resample_img, os.path.join(path, "orig.jpg" %(im_it)))
create_interpolation_video(gen, F.interpolate(resample_img, scale_factor=1/scale, mode="area"),
path, [torch.zeros_like(l) for l in lats[j]], lats[j])

View File

@ -85,8 +85,8 @@ def main():
if dataset == 'REDS' or dataset == 'Vimeo90K':
LQs = data['LQs']
else:
LQ = data['LQ']
GT = data['GT']
LQ = data['lq']
GT = data['hq']
if dataset == 'REDS' or dataset == 'Vimeo90K':
for j in range(LQs.size(1)):

View File

@ -68,6 +68,6 @@ if __name__ == "__main__":
# removed += 1
imname = osp.basename(data['GT_path'][i])
if results[i]-dataset_mean > 1:
torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@ -66,7 +66,7 @@ if __name__ == "__main__":
model.test()
gen = model.eval_state['gen'][0].to(model.env['device'])
feagen = netF(gen)
feareal = netF(data['GT'].to(model.env['device']))
feareal = netF(data['hq'].to(model.env['device']))
losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
means.append(torch.mean(losses).item())
#print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
@ -76,6 +76,6 @@ if __name__ == "__main__":
removed += 1
#imname = osp.basename(data['GT_path'][i])
#if losses[i] < 25000:
# torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
# torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
print("Removed %i/%i images" % (removed, len(test_set)))

View File

@ -41,9 +41,9 @@ def forward_pass(model, output_dir, alteration_suffix=''):
save_img_path = osp.join(output_dir, img_name + '.png')
if need_GT:
fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i])
fea_loss += model.compute_fea_loss(visuals[i], data['hq'][i])
psnr_sr = util.tensor2img(visuals[i])
psnr_gt = util.tensor2img(data['GT'][i])
psnr_gt = util.tensor2img(data['hq'][i])
psnr_loss += util.calculate_psnr(psnr_sr, psnr_gt)
util.save_img(sr_img, save_img_path)

View File

@ -231,13 +231,13 @@ class Trainer:
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)

View File

@ -3,6 +3,8 @@ import math
import argparse
import random
import logging
import torchvision
from tqdm import tqdm
import torch
@ -231,18 +233,18 @@ class Trainer:
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
# calculate PSNR
if self.val_compute_psnr:
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
gt_img = util.tensor2img(visuals['hq'][b]) # uint8
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img)
# calculate fea loss
if self.val_compute_fea:
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['hq'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
save_img_path = os.path.join(img_dir, img_base_name)
util.save_img(sr_img, save_img_path)
torchvision.utils.save_image(visuals['rlt'], save_img_path)
avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
@ -291,7 +293,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_frompsnr.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_srflow_bigboi_frompsnr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()