34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
import torch
|
|
from data import create_dataset
|
|
|
|
|
|
# Simple composite dataset that combines multiple other datasets.
|
|
# Assumes that the datasets output dicts.
|
|
class CombinedDataset(torch.utils.data.Dataset):
    """Composite dataset that merges the items of several named sub-datasets.

    Every dict-valued entry of ``opt`` is treated as the options for one
    sub-dataset and is built with ``create_dataset``. Sub-datasets are
    expected to return dicts from ``__getitem__``; the dicts are merged into
    a single output dict, with keys from every sub-dataset except the one
    named ``'default'`` prefixed by ``"<name>_"``.
    """

    def __init__(self, opt):
        """Build one sub-dataset per dict-valued entry in ``opt``.

        Args:
            opt: Mapping of option names to values. Non-dict values are
                ignored when looking for sub-datasets. ``opt['scale']`` and
                ``opt['phase']`` are read whenever at least one sub-dataset
                entry is present.

        Raises:
            KeyError: If a sub-dataset entry exists but ``opt`` lacks
                ``'scale'`` or ``'phase'``.
        """
        super().__init__()
        self.datasets = {}
        for name, sub_opt in opt.items():
            if not isinstance(sub_opt, dict):
                continue
            # Scale & phase normally get injected by options.py; nested
            # dataset options do not receive them, so copy them down here.
            # NOTE: this intentionally writes into the caller's nested dict.
            sub_opt['scale'] = opt['scale']
            sub_opt['phase'] = opt['phase']
            self.datasets[name] = create_dataset(sub_opt)
        # Running count of __getitem__ calls made on this instance.
        self.items_fetched = 0

    def __getitem__(self, i):
        """Return the merged item dict for index ``i``.

        Each sub-dataset is indexed with ``i % len(sub_dataset)``, so
        sub-datasets shorter than the longest one wrap around.

        Raises:
            ZeroDivisionError: If any sub-dataset is empty.
        """
        self.items_fetched += 1
        output = {}
        for name, dataset in self.datasets.items():
            # 'default' dataset gets no prefix, other ones get `<name>_`.
            prefix = "" if name == 'default' else name + "_"
            item = dataset[i % len(dataset)]
            output.update((prefix + key, value) for key, value in item.items())
        return output

    def __len__(self):
        """Return the length of the longest sub-dataset (0 if there are none)."""
        # default=0 avoids ValueError from max() when no sub-datasets exist.
        return max((len(d) for d in self.datasets.values()), default=0)