DL-Art-School/codes/data/__init__.py
James Betker 44b89330c2 Support inference across batches, support inference on cpu, checkpoint
This is a checkpoint of a set of long tests with reduced-complexity networks. Some takeaways:
1) A full GAN using the resnet discriminator does appear to converge, but the quality is capped.
2) By contrast, a combined GAN/feature loss does not converge. The feature loss is optimized, but
    the model appears unable to fight the discriminator, so the G-loss steadily increases.

Going forward, I want to try some bigger models. In particular, I want to change the generator
to increase complexity and capacity. I also want to add skip connections between the
discriminator and generator.
2020-05-04 08:48:25 -06:00

54 lines
2.1 KiB
Python

"""create dataset and dataloader"""
import logging
import torch
import torch.utils.data
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` configured for its phase.

    Args:
        dataset: The dataset to load from.
        dataset_opt (dict): Per-dataset options. ``'phase'`` is always read.
            For the ``'train'`` phase, ``'n_workers'`` and ``'batch_size'`` are
            required. For any other phase, ``'batch_size'`` is optional and
            defaults to 1 when missing or falsy.
        opt (dict, optional): Global options. Required for the ``'train'``
            phase, where ``'dist'`` and (when not distributed) ``'gpu_ids'``
            are read.
        sampler (optional): Passed straight through to the training DataLoader.
            When provided, shuffling is disabled, because DataLoader rejects
            ``shuffle=True`` combined with an explicit sampler.

    Returns:
        torch.utils.data.DataLoader: For training, incomplete final batches are
        dropped. For other phases, the loader never shuffles and uses no worker
        processes.

    Raises:
        ValueError: If the train phase is requested without ``opt``, or if the
            training batch size is not divisible by the distributed world size.
    """
    phase = dataset_opt['phase']
    if phase == 'train':
        if opt is None:
            raise ValueError("create_dataloader: 'opt' is required for the 'train' phase.")
        if opt['dist']:
            world_size = torch.distributed.get_world_size()
            num_workers = dataset_opt['n_workers']
            total_batch = dataset_opt['batch_size']
            # Each process loads an equal share of the global batch.
            if total_batch % world_size != 0:
                raise ValueError(
                    'Training batch_size ({}) must be divisible by the world size ({}).'.format(
                        total_batch, world_size))
            batch_size = total_batch // world_size
            # Shuffling is left off in distributed mode. Presumably the caller
            # supplies a distributed sampler that handles ordering (confirm).
            shuffle = False
        else:
            # Worker count scales with the number of GPUs in use.
            num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
            batch_size = dataset_opt['batch_size']
            # DataLoader forbids shuffle=True together with an explicit sampler.
            shuffle = sampler is None
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, sampler=sampler, drop_last=True,
                                           pin_memory=False)
    # Non-training phases (e.g. validation/inference): a deterministic order,
    # and batching is supported but optional.
    batch_size = dataset_opt.get('batch_size') or 1
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                                       pin_memory=False)
# Maps each supported ``dataset_opt['mode']`` to the (module, class) implementing it.
# Modules are imported only when their mode is selected, so unused dataset modules
# (and their dependencies) are never loaded.
_DATASET_CLASSES = {
    # datasets for image restoration
    'LQ': ('data.LQ_dataset', 'LQDataset'),
    'LQGT': ('data.LQGT_dataset', 'LQGTDataset'),
    # datasets for image corruption
    'downsample': ('data.Downsample_dataset', 'DownsampleDataset'),
    # datasets for video restoration
    'REDS': ('data.REDS_dataset', 'REDSDataset'),
    'Vimeo90K': ('data.Vimeo90K_dataset', 'Vimeo90KDataset'),
    'video_test': ('data.video_test_dataset', 'VideoTestDataset'),
}


def create_dataset(dataset_opt):
    """Instantiate the dataset class selected by ``dataset_opt['mode']``.

    Args:
        dataset_opt (dict): Per-dataset options. ``'mode'`` selects the class.
            The whole dict is passed to that class's constructor, and
            ``'name'`` is used in the log message.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``'mode'`` is not a recognized dataset mode,
            including non-string or unhashable values.
    """
    import importlib

    mode = dataset_opt['mode']
    try:
        module_name, class_name = _DATASET_CLASSES[mode]
    except (KeyError, TypeError):
        # TypeError covers unhashable modes; either way the mode is unsupported.
        # '{}' (not '{:s}') so a non-str mode still yields this error, not a TypeError.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode)) from None
    dataset_cls = getattr(importlib.import_module(module_name), class_name)
    dataset = dataset_cls(dataset_opt)
    logger = logging.getLogger('base')
    # Lazy %-args: also safe if the name is not a str (the old '{:s}' would crash here).
    logger.info('Dataset [%s - %s] is created.', dataset.__class__.__name__, dataset_opt['name'])
    return dataset