Some random fixes/adjustments
This commit is contained in:
parent
2538ca9f33
commit
f4b33b0531
|
@ -21,7 +21,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
||||||
num_workers=num_workers, sampler=sampler, drop_last=True,
|
num_workers=num_workers, sampler=sampler, drop_last=True,
|
||||||
pin_memory=False)
|
pin_memory=False)
|
||||||
else:
|
else:
|
||||||
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
|
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
|
||||||
pin_memory=False)
|
pin_memory=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,17 +19,17 @@ def main():
|
||||||
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
|
||||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||||
if mode == 'single':
|
if mode == 'single':
|
||||||
opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR'
|
opt['input_folder'] = '../../datasets/div2k/DIV2K_train_HR'
|
||||||
opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
|
opt['save_folder'] = '../../datasets/div2k/DIV2K800_sub'
|
||||||
opt['crop_sz'] = 480 # the size of each sub-image
|
opt['crop_sz'] = 480 # the size of each sub-image
|
||||||
opt['step'] = 240 # step of the sliding crop window
|
opt['step'] = 240 # step of the sliding crop window
|
||||||
opt['thres_sz'] = 48 # size threshold
|
opt['thres_sz'] = 48 # size threshold
|
||||||
extract_signle(opt)
|
extract_signle(opt)
|
||||||
elif mode == 'pair':
|
elif mode == 'pair':
|
||||||
GT_folder = '../../datasets/DIV2K/DIV2K_train_HR'
|
GT_folder = '../../datasets/div2k/DIV2K_train_HR'
|
||||||
LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
|
LR_folder = '../../datasets/div2k/DIV2K_train_LR_bicubic/X4'
|
||||||
save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub'
|
save_GT_folder = '../../datasets/div2k/DIV2K800_sub'
|
||||||
save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
|
save_LR_folder = '../../datasets/div2k/DIV2K800_sub_bicLRx4'
|
||||||
scale_ratio = 4
|
scale_ratio = 4
|
||||||
crop_sz = 480 # the size of each sub-image (GT)
|
crop_sz = 480 # the size of each sub-image (GT)
|
||||||
step = 240 # step of the sliding crop window (GT)
|
step = 240 # step of the sliding crop window (GT)
|
||||||
|
|
|
@ -3,7 +3,7 @@ import glob
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
|
folder = 'datasets/div2k/DIV2K_valid_LR_bicubic/X4'
|
||||||
DIV2K(folder)
|
DIV2K(folder)
|
||||||
print('Finished.')
|
print('Finished.')
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ from models import create_model
|
||||||
|
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.')
|
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
|
|
|
@ -3,10 +3,11 @@ import math
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
#import torch.distributed as dist
|
||||||
import torch.multiprocessing as mp
|
#import torch.multiprocessing as mp
|
||||||
from data.data_sampler import DistIterSampler
|
from data.data_sampler import DistIterSampler
|
||||||
|
|
||||||
import options.options as option
|
import options.options as option
|
||||||
|
@ -28,7 +29,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/finetune_ESRGAN_blacked.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
@ -138,7 +139,7 @@ def main():
|
||||||
current_step = resume_state['iter']
|
current_step = resume_state['iter']
|
||||||
model.resume_training(resume_state) # handle optimizers and schedulers
|
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||||
else:
|
else:
|
||||||
current_step = 0
|
current_step = -1
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
|
|
||||||
#### training
|
#### training
|
||||||
|
@ -146,7 +147,8 @@ def main():
|
||||||
for epoch in range(start_epoch, total_epochs + 1):
|
for epoch in range(start_epoch, total_epochs + 1):
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
for _, train_data in enumerate(train_loader):
|
tq_ldr = tqdm(train_loader)
|
||||||
|
for _, train_data in enumerate(tq_ldr):
|
||||||
current_step += 1
|
current_step += 1
|
||||||
if current_step > total_iters:
|
if current_step > total_iters:
|
||||||
break
|
break
|
||||||
|
|
Loading…
Reference in New Issue
Block a user