Add combined dataset for training across multiple datasets
parent
313424d7b5
commit
5189b11dac
@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from data import create_dataset
|
||||
|
||||
|
||||
# Simple composite dataset that combines multiple other datasets.
|
||||
# Assumes that the datasets output dicts.
|
||||
class CombinedDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, opt):
|
||||
self.datasets = {}
|
||||
for k, v in opt.items():
|
||||
if not isinstance(v, dict):
|
||||
continue
|
||||
# Scale&phase gets injected by options.py..
|
||||
v['scale'] = opt['scale']
|
||||
v['phase'] = opt['phase']
|
||||
self.datasets[k] = create_dataset(v)
|
||||
self.items_fetched = 0
|
||||
|
||||
def __getitem__(self, i):
|
||||
self.items_fetched += 1
|
||||
output = {}
|
||||
for name, dataset in self.datasets.items():
|
||||
prefix = ""
|
||||
# 'default' dataset gets no prefix, other ones get `key_`
|
||||
if name != 'default':
|
||||
prefix = name + "_"
|
||||
|
||||
data = dataset[i % len(dataset)]
|
||||
for k, v in data.items():
|
||||
output[prefix + k] = v
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
return max(len(d) for d in self.datasets.values())
|
Loading…
Reference in New Issue