diff --git a/codes/data/torch_dataset.py b/codes/data/torch_dataset.py index ca259e8f..0917483c 100644 --- a/codes/data/torch_dataset.py +++ b/codes/data/torch_dataset.py @@ -35,15 +35,16 @@ class TorchDataset(Dataset): ] transforms = T.Compose(transforms) self.dataset = DATASET_MAP[opt['dataset']](transform=transforms, **opt['kwargs']) - self.len = opt['fixed_len'] if 'fixed_len' in opt.keys() else len(self.dataset) + self.len = opt_get(opt, ['fixed_len'], len(self.dataset)) + self.offset = opt_get(opt, ['offset'], 0) def __getitem__(self, item): - underlying_item, lbl = self.dataset[item] + underlying_item, lbl = self.dataset[item+self.offset] return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl, 'LQ_path': str(item), 'GT_path': str(item)} def __len__(self): - return self.len + return self.len-self.offset if __name__ == '__main__': opt = {