diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index a13948a5..b8321963 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -5,7 +5,7 @@ import lmdb import torch import torch.utils.data as data import data.util as util -from PIL import Image +from PIL import Image, ImageOps from io import BytesIO import torchvision.transforms.functional as F @@ -180,6 +180,9 @@ class LQGTDataset(data.Dataset): corruption_buffer.seek(0) img_LQ = Image.open(corruption_buffer) + if self.opt['grayscale']: + img_LQ = ImageOps.grayscale(img_LQ) + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float() img_LQ = F.to_tensor(img_LQ)