Some random fixes/adjustments

pull/9/head
James Betker 2020-04-22 00:38:53 +07:00
parent 2538ca9f33
commit f4b33b0531
5 changed files with 16 additions and 14 deletions

@ -21,7 +21,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
num_workers=num_workers, sampler=sampler, drop_last=True,
pin_memory=False)
else:
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
pin_memory=False)

@ -19,17 +19,17 @@ def main():
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.
if mode == 'single':
opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR'
opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
opt['input_folder'] = '../../datasets/div2k/DIV2K_train_HR'
opt['save_folder'] = '../../datasets/div2k/DIV2K800_sub'
opt['crop_sz'] = 480 # the size of each sub-image
opt['step'] = 240 # step of the sliding crop window
opt['thres_sz'] = 48 # size threshold
extract_signle(opt)
elif mode == 'pair':
GT_folder = '../../datasets/DIV2K/DIV2K_train_HR'
LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub'
save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
GT_folder = '../../datasets/div2k/DIV2K_train_HR'
LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4'
save_GT_folder = '../../datasets/div2k/DIV2K800_sub'
save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4'
scale_ratio = 4
crop_sz = 480 # the size of each sub-image (GT)
step = 240 # step of the sliding crop window (GT)

@ -3,7 +3,7 @@ import glob
def main():
folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4'
DIV2K(folder)
print('Finished.')

@ -12,7 +12,7 @@ from models import create_model
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.')
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)

@ -3,10 +3,11 @@ import math
import argparse
import random
import logging
from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
#import torch.distributed as dist
#import torch.multiprocessing as mp
from data.data_sampler import DistIterSampler
import options.options as option
@ -28,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
@ -138,7 +139,7 @@ def main():
current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers
else:
current_step = 0
current_step = -1
start_epoch = 0
#### training
@ -146,7 +147,8 @@ def main():
for epoch in range(start_epoch, total_epochs + 1):
if opt['dist']:
train_sampler.set_epoch(epoch)
for _, train_data in enumerate(train_loader):
tq_ldr = tqdm(train_loader)
for _, train_data in enumerate(tq_ldr):
current_step += 1
if current_step > total_iters:
break