Forked from mrq/DL-Art-School
Commit 5189b11dac (parent 313424d7b5): Add combined dataset for training across multiple datasets.
@ -14,7 +14,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
||||||
batch_size = dataset_opt['batch_size'] // world_size
|
batch_size = dataset_opt['batch_size'] // world_size
|
||||||
shuffle = False
|
shuffle = False
|
||||||
else:
|
else:
|
||||||
num_workers = max(dataset_opt['n_workers'] * len(opt['gpu_ids']), 10)
|
num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
|
||||||
batch_size = dataset_opt['batch_size']
|
batch_size = dataset_opt['batch_size']
|
||||||
shuffle = True
|
shuffle = True
|
||||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||||
|
@ -38,6 +38,8 @@ def create_dataset(dataset_opt):
|
||||||
from data.Downsample_dataset import DownsampleDataset as D
|
from data.Downsample_dataset import DownsampleDataset as D
|
||||||
elif mode == 'fullimage':
|
elif mode == 'fullimage':
|
||||||
from data.full_image_dataset import FullImageDataset as D
|
from data.full_image_dataset import FullImageDataset as D
|
||||||
|
elif mode == 'combined':
|
||||||
|
from data.combined_dataset import CombinedDataset as D
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||||
dataset = D(dataset_opt)
|
dataset = D(dataset_opt)
|
||||||
|
|
34 lines added — codes/data/combined_dataset.py (new file)
|
@ -0,0 +1,34 @@
|
||||||
|
import torch
|
||||||
|
from data import create_dataset
|
||||||
|
|
||||||
|
|
||||||
|
# Simple composite dataset that combines multiple other datasets.
# Assumes that the wrapped datasets output dicts.
class CombinedDataset(torch.utils.data.Dataset):
    """Composite dataset that merges samples from several sub-datasets.

    Every dict-valued entry in ``opt`` is treated as a sub-dataset
    configuration and instantiated via ``create_dataset``. On fetch, one
    sample is drawn from each sub-dataset and their keys are merged into a
    single dict: keys from the 'default' dataset are kept as-is, keys from
    any other dataset are prefixed with ``<name>_`` to avoid collisions.
    """

    def __init__(self, opt):
        self.datasets = {}
        for k, v in opt.items():
            # Non-dict entries (e.g. 'scale', 'phase', 'mode') are global
            # options, not sub-dataset configs — skip them.
            if not isinstance(v, dict):
                continue
            # scale & phase get injected by options.py at the top level;
            # propagate them into each sub-dataset's config.
            # NOTE(review): this mutates the caller's sub-dict in place.
            v['scale'] = opt['scale']
            v['phase'] = opt['phase']
            self.datasets[k] = create_dataset(v)
        # Running counter of samples served; handy for debugging/telemetry.
        self.items_fetched = 0

    def __getitem__(self, i):
        self.items_fetched += 1
        output = {}
        for name, dataset in self.datasets.items():
            # 'default' dataset gets no prefix; other ones get `<name>_`.
            prefix = "" if name == 'default' else name + "_"
            # Wrap the index so shorter datasets repeat instead of raising.
            data = dataset[i % len(dataset)]
            for k, v in data.items():
                output[prefix + k] = v
        return output

    def __len__(self):
        # Length of the longest sub-dataset; shorter ones wrap around in
        # __getitem__. Fix: supply default=0 so a config with no
        # sub-datasets reports length 0 instead of raising ValueError
        # from max() on an empty sequence.
        return max((len(d) for d in self.datasets.values()), default=0)
|
|
@ -91,6 +91,9 @@ class ConfigurableStep(Module):
|
||||||
# Don't do injections tagged with eval unless we are not in train mode.
|
# Don't do injections tagged with eval unless we are not in train mode.
|
||||||
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
|
if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
|
||||||
continue
|
continue
|
||||||
|
# Likewise, don't do injections tagged with train unless we are not in eval.
|
||||||
|
if not train and 'train' in inj.opt.keys() and inj.opt['train']:
|
||||||
|
continue
|
||||||
injected = inj(local_state)
|
injected = inj(local_state)
|
||||||
local_state.update(injected)
|
local_state.update(injected)
|
||||||
new_state.update(injected)
|
new_state.update(injected)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user