From 5189b11dac31c1eae67bc9c7183d278c0535df4a Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 11 Sep 2020 08:44:06 -0600
Subject: [PATCH] Add combined dataset for training across multiple datasets

---
 codes/data/__init__.py         |  4 +++-
 codes/data/combined_dataset.py | 34 ++++++++++++++++++++++++++++++++++
 codes/models/steps/steps.py    |  3 +++
 3 files changed, 40 insertions(+), 1 deletion(-)
 create mode 100644 codes/data/combined_dataset.py

diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index e7070974..f3fff285 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -14,7 +14,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
         batch_size = dataset_opt['batch_size'] // world_size
         shuffle = False
     else:
-        num_workers = max(dataset_opt['n_workers'] * len(opt['gpu_ids']), 10)
+        num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
         batch_size = dataset_opt['batch_size']
         shuffle = True
     return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
@@ -38,6 +38,8 @@ def create_dataset(dataset_opt):
         from data.Downsample_dataset import DownsampleDataset as D
     elif mode == 'fullimage':
         from data.full_image_dataset import FullImageDataset as D
+    elif mode == 'combined':
+        from data.combined_dataset import CombinedDataset as D
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
diff --git a/codes/data/combined_dataset.py b/codes/data/combined_dataset.py
new file mode 100644
index 00000000..b91b1b94
--- /dev/null
+++ b/codes/data/combined_dataset.py
@@ -0,0 +1,34 @@
+import torch
+from data import create_dataset
+
+
+# Simple composite dataset that combines multiple other datasets.
+# Assumes that the datasets output dicts.
+class CombinedDataset(torch.utils.data.Dataset):
+    def __init__(self, opt):
+        self.datasets = {}
+        for k, v in opt.items():
+            if not isinstance(v, dict):
+                continue
+            # Scale&phase gets injected by options.py..
+            v['scale'] = opt['scale']
+            v['phase'] = opt['phase']
+            self.datasets[k] = create_dataset(v)
+        self.items_fetched = 0
+
+    def __getitem__(self, i):
+        self.items_fetched += 1
+        output = {}
+        for name, dataset in self.datasets.items():
+            prefix = ""
+            # 'default' dataset gets no prefix, other ones get `key_`
+            if name != 'default':
+                prefix = name + "_"
+
+            data = dataset[i % len(dataset)]
+            for k, v in data.items():
+                output[prefix + k] = v
+        return output
+
+    def __len__(self):
+        return max(len(d) for d in self.datasets.values())
\ No newline at end of file
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index e7de2e8d..d1136883 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -91,6 +91,9 @@ class ConfigurableStep(Module):
             # Don't do injections tagged with eval unless we are not in train mode.
            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
                 continue
+            # Likewise, don't do injections tagged with train unless we are not in eval.
+            if not train and 'train' in inj.opt.keys() and inj.opt['train']:
+                continue
             injected = inj(local_state)
             local_state.update(injected)
             new_state.update(injected)