forked from mrq/DL-Art-School
34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
|
import torch
|
||
|
from data import create_dataset
|
||
|
|
||
|
|
||
|
class CombinedDataset(torch.utils.data.Dataset):
    """Simple composite dataset that combines multiple other datasets.

    Every entry of ``opt`` whose value is a dict is treated as the options of
    a sub-dataset and built with ``create_dataset``. Non-dict entries are
    ignored. Sub-datasets are assumed to return dicts from ``__getitem__``.

    The combined sample at index ``i`` merges each sub-dataset's sample at
    ``i % len(sub_dataset)``. Shorter sub-datasets therefore wrap around, and
    the combined length is that of the longest sub-dataset.

    Key prefixes:
    - Keys from the sub-dataset named ``'default'`` are kept unchanged.
    - Keys from any other sub-dataset ``name`` are prefixed with ``name + '_'``.

    If two resulting keys collide, the sub-dataset that comes later in
    ``opt``'s iteration order overwrites the earlier one.
    """

    def __init__(self, opt):
        """Build one sub-dataset per dict-valued entry of ``opt``.

        Args:
            opt: Options mapping. It must provide ``'scale'`` and ``'phase'``.
                Both are written (in place) into every nested sub-dataset
                option dict before that dict is passed to ``create_dataset``.
        """
        # Sub-dataset name -> dataset instance, in ``opt`` iteration order.
        self.datasets = {}
        for name, sub_opt in opt.items():
            if not isinstance(sub_opt, dict):
                continue
            # Scale & phase get injected into this opt by options.py, presumably
            # not into the nested dicts, so propagate them to each sub-dataset.
            sub_opt['scale'] = opt['scale']
            sub_opt['phase'] = opt['phase']
            self.datasets[name] = create_dataset(sub_opt)
        # Number of in-range __getitem__ calls made on this instance.
        self.items_fetched = 0

    def __getitem__(self, i):
        """Return the merged sample dict for combined index ``i``.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
                Without this check, the per-sub-dataset modulo would accept
                any index, and plain iteration over the dataset would never
                stop.
            ValueError: if a sub-dataset is empty and so cannot supply a
                sample.
        """
        total = len(self)
        if not -total <= i < total:
            raise IndexError(
                "CombinedDataset index {} out of range for length {}".format(i, total))
        self.items_fetched += 1
        output = {}
        for name, dataset in self.datasets.items():
            # 'default' dataset gets no prefix, other ones get `key_`
            prefix = "" if name == 'default' else name + "_"
            size = len(dataset)
            if size == 0:
                raise ValueError(
                    "CombinedDataset sub-dataset '{}' is empty".format(name))
            # Shorter sub-datasets wrap around to cover the longest one's range.
            data = dataset[i % size]
            for k, v in data.items():
                output[prefix + k] = v
        return output

    def __len__(self):
        """Return the length of the longest sub-dataset, or 0 when there are none."""
        return max((len(d) for d in self.datasets.values()), default=0)
|