diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index b8321963..47a2859d 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -181,7 +181,7 @@ class LQGTDataset(data.Dataset): img_LQ = Image.open(corruption_buffer) if self.opt['grayscale']: - img_LQ = ImageOps.grayscale(img_LQ) + img_LQ = ImageOps.grayscale(img_LQ).convert('RGB') img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() img_PIX = torch.from_numpy(np.ascontiguousarray(np.transpose(img_PIX, (2, 0, 1)))).float()