Allow validation for ce
This commit is contained in:
parent
7c251af7a8
commit
e6c537824a
|
@ -35,15 +35,16 @@ class TorchDataset(Dataset):
|
|||
]
|
||||
transforms = T.Compose(transforms)
|
||||
self.dataset = DATASET_MAP[opt['dataset']](transform=transforms, **opt['kwargs'])
|
||||
self.len = opt['fixed_len'] if 'fixed_len' in opt.keys() else len(self.dataset)
|
||||
self.len = opt_get(opt, ['fixed_len'], len(self.dataset))
|
||||
self.offset = opt_get(opt, ['offset'], 0)
|
||||
|
||||
def __getitem__(self, item):
|
||||
underlying_item, lbl = self.dataset[item]
|
||||
underlying_item, lbl = self.dataset[item+self.offset]
|
||||
return {'lq': underlying_item, 'hq': underlying_item, 'labels': lbl,
|
||||
'LQ_path': str(item), 'GT_path': str(item)}
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
return self.len-self.offset
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user