DL-Art-School/codes/data_scripts/validate_data.py
2020-08-22 08:24:34 -06:00

67 lines
2.3 KiB
Python

# This script builds the training dataset and dataloader with no worker threads (so any prescribed transformations
# would run in-process) and then tries to read every image listed in the dataset's `paths_GT`.
# The idea is to find bad/corrupt images.
import math
import argparse
import random
import torch
import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from time import time
from tqdm import tqdm
from skimage import io
def main():
    """Try to read every image listed by the training dataset and report failures.

    Parses the option YAML given by ``-opt``, seeds the RNGs, builds the
    ``train`` dataset and an in-process dataloader (``n_workers`` forced to 0),
    then calls ``skimage.io.imread`` on every entry of ``train_set.paths_GT``.
    Any exception raised while reading a path is printed together with the
    path, and the scan continues with the next one.

    Raises:
        ValueError: if the options file defines no ``train`` dataset.
        AssertionError: if ``create_dataloader`` returns ``None``.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.',
                        default='../../options/train_mi1_spsr_switched2.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    # This script always runs as a single, non-distributed process.
    opt['dist'] = False
    rank = -1

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train dataset and dataloader
    # Bind both names up front so a config without a 'train' phase produces a
    # clear error below instead of an UnboundLocalError.
    train_set = None
    train_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase != 'train':
            continue
        train_set = create_dataset(dataset_opt)
        train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
        dataset_opt['n_workers'] = 0  # Force num_workers=0 to make dataloader work in process.
        train_loader = create_dataloader(train_set, dataset_opt, opt, None)
        if rank <= 0:
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))

    if train_set is None:
        raise ValueError(
            "No 'train' dataset is defined in option file %r; nothing to validate." % (args.opt,))
    assert train_loader is not None

    #### scan every listed image path
    for path in tqdm(train_set.paths_GT):
        try:
            _ = io.imread(path)
            # Do stuff with img
        except Exception as e:
            # Deliberately broad: any failure to read marks the file as bad,
            # and the scan must keep going to find the remaining ones.
            print("Error with %s" % (path,))
            print(e)
# Run the validation pass only when executed as a script, not on import.
if __name__ == "__main__":
    main()