Allow validation for ce

This commit is contained in:
James Betker 2021-06-04 21:21:04 -06:00
parent 7c251af7a8
commit e6c537824a

View File

@ -35,15 +35,16 @@ class TorchDataset(Dataset):
]
transforms = T.Compose(transforms)
self.dataset = DATASET_MAP[opt['dataset']](transform=transforms, **opt['kwargs'])
self.len = opt['fixed_len'] if 'fixed_len' in opt.keys() else len(self.dataset)
self.len = opt_get(opt, ['fixed_len'], len(self.dataset))
self.offset = opt_get(opt, ['offset'], 0)
def __getitem__(self, item):
underlying_item, lbl = self.dataset[item]
underlying_item, lbl = self.dataset[item+self.offset]
return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl,
'LQ_path': str(item), 'GT_path': str(item)}
def __len__(self):
return self.len
return self.len-self.offset
if __name__ == '__main__':
opt = {