Some random fixes/adjustments

This commit is contained in:
James Betker 2020-04-22 00:38:53 -06:00
parent 2538ca9f33
commit f4b33b0531
5 changed files with 16 additions and 14 deletions

View File

@ -21,7 +21,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
num_workers=num_workers, sampler=sampler, drop_last=True, num_workers=num_workers, sampler=sampler, drop_last=True,
pin_memory=False) pin_memory=False)
else: else:
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
pin_memory=False) pin_memory=False)

View File

@ -19,17 +19,17 @@ def main():
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed. # compression time. If read raw images during training, use 0 for faster IO speed.
if mode == 'single': if mode == 'single':
opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR' opt['input_folder'] = '../../datasets/div2k/DIV2K_train_HR'
opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub' opt['save_folder'] = '../../datasets/div2k/DIV2K800_sub'
opt['crop_sz'] = 480 # the size of each sub-image opt['crop_sz'] = 480 # the size of each sub-image
opt['step'] = 240 # step of the sliding crop window opt['step'] = 240 # step of the sliding crop window
opt['thres_sz'] = 48 # size threshold opt['thres_sz'] = 48 # size threshold
extract_signle(opt) extract_signle(opt)
elif mode == 'pair': elif mode == 'pair':
GT_folder = '../../datasets/DIV2K/DIV2K_train_HR' GT_folder = '../../datasets/div2k/DIV2K_train_HR'
LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4'
save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub' save_GT_folder = '../../datasets/div2k/DIV2K800_sub'
save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4' save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4'
scale_ratio = 4 scale_ratio = 4
crop_sz = 480 # the size of each sub-image (GT) crop_sz = 480 # the size of each sub-image (GT)
step = 240 # step of the sliding crop window (GT) step = 240 # step of the sliding crop window (GT)

View File

@ -3,7 +3,7 @@ import glob
def main(): def main():
folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4'
DIV2K(folder) DIV2K(folder)
print('Finished.') print('Finished.')

View File

@ -12,7 +12,7 @@ from models import create_model
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.') parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)

View File

@ -3,10 +3,11 @@ import math
import argparse import argparse
import random import random
import logging import logging
from tqdm import tqdm
import torch import torch
import torch.distributed as dist #import torch.distributed as dist
import torch.multiprocessing as mp #import torch.multiprocessing as mp
from data.data_sampler import DistIterSampler from data.data_sampler import DistIterSampler
import options.options as option import options.options as option
@ -28,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
@ -138,7 +139,7 @@ def main():
current_step = resume_state['iter'] current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers model.resume_training(resume_state) # handle optimizers and schedulers
else: else:
current_step = 0 current_step = -1
start_epoch = 0 start_epoch = 0
#### training #### training
@ -146,7 +147,8 @@ def main():
for epoch in range(start_epoch, total_epochs + 1): for epoch in range(start_epoch, total_epochs + 1):
if opt['dist']: if opt['dist']:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
for _, train_data in enumerate(train_loader): tq_ldr = tqdm(train_loader)
for _, train_data in enumerate(tq_ldr):
current_step += 1 current_step += 1
if current_step > total_iters: if current_step > total_iters:
break break