diff --git a/codes/data/__init__.py b/codes/data/__init__.py index ce10a834..01bddc01 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -21,7 +21,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): num_workers=num_workers, sampler=sampler, drop_last=True, pin_memory=False) else: - return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, + return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False) diff --git a/codes/data_scripts/extract_subimages.py b/codes/data_scripts/extract_subimages.py index a8185df2..2849afe2 100644 --- a/codes/data_scripts/extract_subimages.py +++ b/codes/data_scripts/extract_subimages.py @@ -19,17 +19,17 @@ def main(): # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # compression time. If read raw images during training, use 0 for faster IO speed. if mode == 'single': - opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR' - opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub' + opt['input_folder'] = '../../datasets/div2k/DIV2K_train_HR' + opt['save_folder'] = '../../datasets/div2k/DIV2K800_sub' opt['crop_sz'] = 480 # the size of each sub-image opt['step'] = 240 # step of the sliding crop window opt['thres_sz'] = 48 # size threshold extract_signle(opt) elif mode == 'pair': - GT_folder = '../../datasets/DIV2K/DIV2K_train_HR' - LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' - save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub' - save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4' + GT_folder = '../../datasets/div2k/DIV2K_train_HR' + LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4' + save_GT_folder = '../../datasets/div2k/DIV2K800_sub' + save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4' scale_ratio = 4 crop_sz = 480 # the size of each sub-image (GT) step = 240 # step of the sliding crop window (GT) diff --git a/codes/data_scripts/rename.py b/codes/data_scripts/rename.py index f8a19552..ded86ed4 100644 --- a/codes/data_scripts/rename.py +++ b/codes/data_scripts/rename.py @@ -3,7 +3,7 @@ import glob def main(): - folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' + folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4' DIV2K(folder) print('Finished.') diff --git a/codes/test.py b/codes/test.py index b07a44b7..39ed79e5 100644 --- a/codes/test.py +++ b/codes/test.py @@ -12,7 +12,7 @@ from models import create_model #### options parser = argparse.ArgumentParser() -parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.') +parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) diff --git a/codes/train.py b/codes/train.py index c8c29bde..8721d33f 100644 --- a/codes/train.py +++ b/codes/train.py @@ -3,10 +3,11 @@ import math import argparse import random import logging +from tqdm import tqdm import torch -import torch.distributed as dist -import torch.multiprocessing as mp +#import torch.distributed as dist +#import torch.multiprocessing as mp from data.data_sampler import DistIterSampler import options.options as option @@ -28,7 +29,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -138,7 +139,7 @@ def main(): current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: - current_step = 0 + current_step = -1 start_epoch = 0 #### training @@ -146,7 +147,8 @@ def main(): for epoch in range(start_epoch, total_epochs + 1): if opt['dist']: train_sampler.set_epoch(epoch) - for _, train_data in enumerate(train_loader): + tq_ldr = tqdm(train_loader) + for _, train_data in enumerate(tq_ldr): current_step += 1 if current_step > total_iters: break