forked from mrq/DL-Art-School
Some random fixes/adjustments
This commit is contained in:
parent 2538ca9f33
commit f4b33b0531
@@ -21,7 +21,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
                                            num_workers=num_workers, sampler=sampler, drop_last=True,
                                            pin_memory=False)
     else:
-        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
+        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
                                            pin_memory=False)
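Note on the change above: with `num_workers=0`, PyTorch loads batches in the main process instead of spawning worker processes, which avoids worker startup and pickling issues (notably on Windows) at the cost of background prefetching. A minimal sketch of the resulting validation-loader configuration, using a throwaway `TensorDataset` as a stand-in for the real dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 3, 32, 32))  # stand-in dataset

# Validation/test loader as configured after this change: one sample per
# batch, no shuffling, and num_workers=0 so loading stays in-process.
val_loader = DataLoader(dataset, batch_size=1, shuffle=False,
                        num_workers=0, pin_memory=False)

for (batch,) in val_loader:
    print(batch.shape)  # torch.Size([1, 3, 32, 32])
```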
@@ -19,17 +19,17 @@ def main():
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
     if mode == 'single':
-        opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR'
-        opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
+        opt['input_folder'] = '../../datasets/div2k/DIV2K_train_HR'
+        opt['save_folder'] = '../../datasets/div2k/DIV2K800_sub'
         opt['crop_sz'] = 480  # the size of each sub-image
         opt['step'] = 240  # step of the sliding crop window
         opt['thres_sz'] = 48  # size threshold
         extract_signle(opt)
     elif mode == 'pair':
-        GT_folder = '../../datasets/DIV2K/DIV2K_train_HR'
-        LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
-        save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub'
-        save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
+        GT_folder = '../../datasets/div2k/DIV2K_train_HR'
+        LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4'
+        save_GT_folder = '../../datasets/div2k/DIV2K800_sub'
+        save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4'
         scale_ratio = 4
         crop_sz = 480  # the size of each sub-image (GT)
         step = 240  # step of the sliding crop window (GT)
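The crop parameters kept above describe a sliding-window tiler: `crop_sz`-pixel windows are placed every `step` pixels, and one extra window is pushed flush against the far edge when the uncovered border exceeds `thres_sz`. A simplified sketch of that arithmetic along one dimension (the `window_starts` helper is illustrative, not part of the repo):

```python
def window_starts(length, crop_sz=480, step=240, thres_sz=48):
    """Start offsets of crop windows along one image dimension."""
    starts = list(range(0, length - crop_sz + 1, step))
    # If the leftover border is wider than the threshold, add a final
    # window aligned with the far edge so no large strip is dropped.
    if length - (starts[-1] + crop_sz) > thres_sz:
        starts.append(length - crop_sz)
    return starts

print(window_starts(1356))  # [0, 240, 480, 720, 876]
```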
@@ -3,7 +3,7 @@ import glob
 
 
 def main():
-    folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
+    folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4'
     DIV2K(folder)
     print('Finished.')
 
@@ -12,7 +12,7 @@ from models import create_model
 
 #### options
 parser = argparse.ArgumentParser()
-parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.')
+parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml')
 opt = option.parse(parser.parse_args().opt, is_train=False)
 opt = option.dict_to_nonedict(opt)
 
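Dropping `required=True` in favor of a `default` makes `-opt` optional: invoking the test script with no arguments now falls back to the baked-in config path. A small sketch of the argparse behavior (paths taken from this diff):

```python
import argparse

parser = argparse.ArgumentParser()
# An optional flag with a fallback: omitting -opt uses the default path
# instead of exiting with "the following arguments are required: -opt".
parser.add_argument('-opt', type=str, default='options/test/test_ESRGAN_vrp.yml',
                    help='Path to options YAML file.')

print(parser.parse_args([]).opt)                 # options/test/test_ESRGAN_vrp.yml
print(parser.parse_args(['-opt', 'x.yml']).opt)  # x.yml
```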
@@ -3,10 +3,11 @@ import math
 import argparse
 import random
 import logging
+from tqdm import tqdm
 
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
+#import torch.distributed as dist
+#import torch.multiprocessing as mp
 from data.data_sampler import DistIterSampler
 
 import options.options as option
@@ -28,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -138,7 +139,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = 0
+        current_step = -1
         start_epoch = 0
 
     #### training
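Starting `current_step` at `-1` instead of `0` appears to compensate for the `current_step += 1` at the top of the loop body (next hunk): the first batch is then processed as step 0 rather than step 1. A toy illustration of the off-by-one:

```python
total_iters = 3

for start in (0, -1):
    current_step, steps = start, []
    for _ in range(10):          # stands in for iterating train_loader
        current_step += 1        # incremented before the batch is used
        if current_step > total_iters:
            break
        steps.append(current_step)
    print(start, steps)          # 0 -> [1, 2, 3]; -1 -> [0, 1, 2, 3]
```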
@@ -146,7 +147,8 @@ def main():
     for epoch in range(start_epoch, total_epochs + 1):
         if opt['dist']:
             train_sampler.set_epoch(epoch)
-        for _, train_data in enumerate(train_loader):
+        tq_ldr = tqdm(train_loader)
+        for _, train_data in enumerate(tq_ldr):
             current_step += 1
             if current_step > total_iters:
                 break
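Wrapping the loader in `tqdm` adds a per-epoch progress bar without changing iteration behavior, since `tqdm` simply proxies the underlying iterable. A minimal sketch (a plain `range` stands in for the real `train_loader`):

```python
from tqdm import tqdm

train_loader = range(100)    # stand-in for the real DataLoader

tq_ldr = tqdm(train_loader)  # renders a live progress bar on stderr
for _, train_data in enumerate(tq_ldr):
    pass                     # training step goes here
```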