Allow dataset classes to add noise internally

James Betker 2020-05-23 21:04:24 -06:00
parent af1968f9e5
commit 74bb0fad33
2 changed files with 50 additions and 29 deletions
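Both files gain the same corruption step: the image is round-tripped through an in-memory JPEG encode at a random quality factor, which bakes realistic compression artifacts into the training data. Below is a minimal sketch of that step, factored into a standalone helper for illustration; the helper name and signature are hypothetical (the commit inlines this logic), but the quality range and the BytesIO round-trip come straight from the diff.

    import random
    from io import BytesIO

    import numpy as np
    from PIL import Image

    def jpeg_corrupt(img, min_qf=15, max_qf=100):
        # img: float HWC RGB array in [0, 1], as produced by util.read_img.
        # Quality range mirrors the randrange(15, 100) call in the diff.
        pil_img = Image.fromarray((img * 255).astype(np.uint8))
        qf = random.randrange(min_qf, max_qf)
        buffer = BytesIO()
        pil_img.save(buffer, "JPEG", quality=qf, optimize=True)
        buffer.seek(0)
        return Image.open(buffer)  # decoded from the buffer; artifacts are now baked in

Note that randrange(15, 100) draws from [15, 99], so quality 100 is never chosen and every corrupted sample carries at least some artifacting.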

View File

@@ -5,6 +5,9 @@ import lmdb
 import torch
 import torch.utils.data as data
 import data.util as util
+from PIL import Image
+from io import BytesIO
+import torchvision.transforms.functional as F


 class DownsampleDataset(data.Dataset):
@@ -64,10 +67,6 @@ class DownsampleDataset(data.Dataset):
         ] if self.data_type == 'lmdb' else None
         img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)

-        # Create a downsampled version of the HQ image using matlab imresize.
-        img_Downsampled = util.imresize_np(img_GT, 1 / scale)
-        assert img_Downsampled.ndim == 3
-
         if self.opt['phase'] == 'train':
             H, W, _ = img_GT.shape
             assert H >= GT_size and W >= GT_size
@@ -80,29 +79,36 @@ class DownsampleDataset(data.Dataset):
             rnd_h = random.randint(0, max(0, H - LQ_size))
             rnd_w = random.randint(0, max(0, W - LQ_size))
             img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
-            img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
             rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
             img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
         else:
             img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
-            img_Downsampled = cv2.resize(img_Downsampled, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
             img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)

         # augmentation - flip, rotate
-        img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled], self.opt['use_flip'],
-                                                       self.opt['use_rot'])
+        img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
+                                      self.opt['use_rot'])

-        if self.opt['color']:  # change color space if necessary
-            img_Downsampled = util.channel_convert(C, self.opt['color'],
-                                                   [img_Downsampled])[0]  # TODO during val no definition
-
         # BGR to RGB, HWC to CHW, numpy to tensor
         if img_GT.shape[2] == 3:
-            img_GT = img_GT[:, :, [2, 1, 0]]
-            img_LQ = img_LQ[:, :, [2, 1, 0]]
-            img_Downsampled = img_Downsampled[:, :, [2, 1, 0]]
+            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
+            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
-        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
-        img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float()
+
+        # HQ needs to go to a PIL image to perform the compression-artifact transformation.
+        H, W, _ = img_GT.shape
+        img_GT = (img_GT * 255).astype(np.uint8)
+        img_GT = Image.fromarray(img_GT)
+        if self.opt['use_compression_artifacts']:
+            qf = random.randrange(15, 100)
+            corruption_buffer = BytesIO()
+            img_GT.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
+            corruption_buffer.seek(0)
+            img_GT = Image.open(corruption_buffer)
+
+        # Generate a downsampled image from HQ for the feature and PIX losses.
+        img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))
+        img_GT = F.to_tensor(img_GT)
+        img_Downsampled = F.to_tensor(img_Downsampled)
         img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()

         # This may seem really messed up, but let me explain:
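Since img_GT is already a PIL image at this point, its tensor conversion moves from the numpy transpose to torchvision's functional API. Both paths produce a CHW float tensor in [0, 1]; a quick hypothetical sanity check, assuming a torchvision version where F.to_tensor and F.resize accept PIL images:

    import numpy as np
    import torch
    import torchvision.transforms.functional as F
    from PIL import Image

    arr = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
    # Old path: HWC numpy -> CHW tensor (values already scaled to [0, 1] upstream).
    via_numpy = torch.from_numpy(np.ascontiguousarray(np.transpose(arr, (2, 0, 1)))).float() / 255.
    # New path: PIL -> tensor; to_tensor also rescales uint8 into [0, 1].
    via_pil = F.to_tensor(Image.fromarray(arr))
    assert torch.allclose(via_numpy, via_pil)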

View File

@@ -5,6 +5,9 @@ import lmdb
 import torch
 import torch.utils.data as data
 import data.util as util
+from PIL import Image
+from io import BytesIO
+import torchvision.transforms.functional as F


 class LQGTDataset(data.Dataset):
@@ -27,16 +30,17 @@ class LQGTDataset(data.Dataset):
         self.LQ_env, self.GT_env, self.PIX_env = None, None, None  # environments for lmdbs
         self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
-        self.paths_LQ = []
-        if isinstance(opt['dataroot_LQ'], list):
-            # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
-            # we want the model to learn them all.
-            for dr_lq in opt['dataroot_LQ']:
-                lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
-                self.paths_LQ.append(lq_path)
-        else:
-            lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
-            self.paths_LQ.append(lq_path)
+        if 'dataroot_LQ' in opt.keys():
+            self.paths_LQ = []
+            if isinstance(opt['dataroot_LQ'], list):
+                # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
+                # we want the model to learn them all.
+                for dr_lq in opt['dataroot_LQ']:
+                    lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
+                    self.paths_LQ.append(lq_path)
+            else:
+                lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
+                self.paths_LQ.append(lq_path)
         self.doCrop = opt['doCrop']

         if 'dataroot_PIX' in opt.keys():
             self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
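With this change dataroot_LQ becomes optional and may also be a list; hypothetical option dicts for the three shapes the constructor above now accepts (paths are made up):

    opt_single = {'dataroot_GT': '/data/gt', 'dataroot_LQ': '/data/lq'}
    opt_multi = {'dataroot_GT': '/data/gt',
                 'dataroot_LQ': ['/data/lq_jpeg', '/data/lq_blur']}  # one entry per corruption type
    opt_no_lq = {'dataroot_GT': '/data/gt'}  # no LQ on disk; note paths_LQ is then never assigned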
@@ -144,12 +148,23 @@ class LQGTDataset(data.Dataset):
         # BGR to RGB, HWC to CHW, numpy to tensor
         if img_GT.shape[2] == 3:
-            img_GT = img_GT[:, :, [2, 1, 0]]
-            img_LQ = img_LQ[:, :, [2, 1, 0]]
-            img_PIX = img_PIX[:, :, [2, 1, 0]]
+            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
+            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
+            img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)
+
+        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
+        img_LQ = (img_LQ * 255).astype(np.uint8)
+        img_LQ = Image.fromarray(img_LQ)
+        if self.opt['use_compression_artifacts']:
+            qf = random.randrange(15, 100)
+            corruption_buffer = BytesIO()
+            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
+            corruption_buffer.seek(0)
+            img_LQ = Image.open(corruption_buffer)
+
         img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
         img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
-        img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
+        img_LQ = F.to_tensor(img_LQ)

         if LQ_path is None:
             LQ_path = GT_path
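Both datasets key the new behavior off the same flag. A hypothetical training-phase options fragment (only use_compression_artifacts is introduced by this commit; the other keys already existed in these datasets):

    opt = {
        'phase': 'train',
        'use_flip': True,
        'use_rot': True,
        'use_compression_artifacts': True,  # enable the random JPEG round-trip
    }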