Allow dataset classes to add noise internally
This commit is contained in:
parent
af1968f9e5
commit
74bb0fad33
|
@ -5,6 +5,9 @@ import lmdb
|
|||
import torch
|
||||
import torch.utils.data as data
|
||||
import data.util as util
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
class DownsampleDataset(data.Dataset):
|
||||
|
@ -64,10 +67,6 @@ class DownsampleDataset(data.Dataset):
|
|||
] if self.data_type == 'lmdb' else None
|
||||
img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
|
||||
|
||||
# Create a downsampled version of the HQ image using matlab imresize.
|
||||
img_Downsampled = util.imresize_np(img_GT, 1 / scale)
|
||||
assert img_Downsampled.ndim == 3
|
||||
|
||||
if self.opt['phase'] == 'train':
|
||||
H, W, _ = img_GT.shape
|
||||
assert H >= GT_size and W >= GT_size
|
||||
|
@ -80,29 +79,36 @@ class DownsampleDataset(data.Dataset):
|
|||
rnd_h = random.randint(0, max(0, H - LQ_size))
|
||||
rnd_w = random.randint(0, max(0, W - LQ_size))
|
||||
img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
img_Downsampled = img_Downsampled[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
|
||||
rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
|
||||
img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
|
||||
else:
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_Downsampled = cv2.resize(img_Downsampled, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# augmentation - flip, rotate
|
||||
img_LQ, img_GT, img_Downsampled = util.augment([img_LQ, img_GT, img_Downsampled], self.opt['use_flip'],
|
||||
img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
|
||||
self.opt['use_rot'])
|
||||
|
||||
if self.opt['color']: # change color space if necessary
|
||||
img_Downsampled = util.channel_convert(C, self.opt['color'],
|
||||
[img_Downsampled])[0] # TODO during val no definition
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if img_GT.shape[2] == 3:
|
||||
img_GT = img_GT[:, :, [2, 1, 0]]
|
||||
img_LQ = img_LQ[:, :, [2, 1, 0]]
|
||||
img_Downsampled = img_Downsampled[:, :, [2, 1, 0]]
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
img_Downsampled = torch.from_numpy(np.ascontiguousarray(np.transpose(img_Downsampled, (2, 0, 1)))).float()
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
|
||||
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# HQ needs to go to a PIL image to perform the compression-artifact transformation.
|
||||
H, W, _ = img_GT.shape
|
||||
img_GT = (img_GT * 255).astype(np.uint8)
|
||||
img_GT = Image.fromarray(img_GT)
|
||||
if self.opt['use_compression_artifacts']:
|
||||
qf = random.randrange(15, 100)
|
||||
corruption_buffer = BytesIO()
|
||||
img_GT.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
|
||||
corruption_buffer.seek(0)
|
||||
img_GT = Image.open(corruption_buffer)
|
||||
# Generate a downsampled image from HQ for feature and PIX losses.
|
||||
img_Downsampled = F.resize(img_GT, (int(H / scale), int(W / scale)))
|
||||
|
||||
img_GT = F.to_tensor(img_GT)
|
||||
img_Downsampled = F.to_tensor(img_Downsampled)
|
||||
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
|
||||
|
||||
# This may seem really messed up, but let me explain:
|
||||
|
|
|
@ -5,6 +5,9 @@ import lmdb
|
|||
import torch
|
||||
import torch.utils.data as data
|
||||
import data.util as util
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
class LQGTDataset(data.Dataset):
|
||||
|
@ -27,16 +30,17 @@ class LQGTDataset(data.Dataset):
|
|||
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
|
||||
|
||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||
self.paths_LQ = []
|
||||
if isinstance(opt['dataroot_LQ'], list):
|
||||
# Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
|
||||
# we want the model to learn them all.
|
||||
for dr_lq in opt['dataroot_LQ']:
|
||||
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
|
||||
if 'dataroot_LQ' in opt.keys():
|
||||
self.paths_LQ = []
|
||||
if isinstance(opt['dataroot_LQ'], list):
|
||||
# Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
|
||||
# we want the model to learn them all.
|
||||
for dr_lq in opt['dataroot_LQ']:
|
||||
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
|
||||
self.paths_LQ.append(lq_path)
|
||||
else:
|
||||
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||
self.paths_LQ.append(lq_path)
|
||||
else:
|
||||
lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
|
||||
self.paths_LQ.append(lq_path)
|
||||
self.doCrop = opt['doCrop']
|
||||
if 'dataroot_PIX' in opt.keys():
|
||||
self.paths_PIX, self.sizes_PIX = util.get_image_paths(self.data_type, opt['dataroot_PIX'])
|
||||
|
@ -144,12 +148,23 @@ class LQGTDataset(data.Dataset):
|
|||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if img_GT.shape[2] == 3:
|
||||
img_GT = img_GT[:, :, [2, 1, 0]]
|
||||
img_LQ = img_LQ[:, :, [2, 1, 0]]
|
||||
img_PIX = img_PIX[:, :, [2, 1, 0]]
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
|
||||
img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
|
||||
img_PIX = cv2.cvtColor(img_PIX, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
|
||||
img_LQ = (img_LQ * 255).astype(np.uint8)
|
||||
img_LQ = Image.fromarray(img_LQ)
|
||||
if self.opt['use_compression_artifacts']:
|
||||
qf = random.randrange(15, 100)
|
||||
corruption_buffer = BytesIO()
|
||||
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
|
||||
corruption_buffer.seek(0)
|
||||
img_LQ = Image.open(corruption_buffer)
|
||||
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()
|
||||
img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
|
||||
img_LQ = F.to_tensor(img_LQ)
|
||||
|
||||
if LQ_path is None:
|
||||
LQ_path = GT_path
|
||||
|
|
Loading…
Reference in New Issue
Block a user