24792bdb4f
Removed a lot of legacy stuff I have no intent on using again. Plan is to shape this repo into something more extensible (get it? hah!)
135 lines
5.5 KiB
Python
135 lines
5.5 KiB
Python
import os
|
|
import os.path as osp
|
|
import logging
|
|
import yaml
|
|
from utils.util import OrderedYaml
|
|
Loader, Dumper = OrderedYaml()
|
|
|
|
def parse(opt_path, is_train=True):
|
|
with open(opt_path, mode='r') as f:
|
|
opt = yaml.load(f, Loader=Loader)
|
|
# export CUDA_VISIBLE_DEVICES
|
|
if 'gpu_ids' in opt.keys():
|
|
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
|
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
|
|
|
opt['is_train'] = is_train
|
|
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
|
scale = opt['scale']
|
|
|
|
# datasets
|
|
if 'datasets' in opt.keys():
|
|
for phase, dataset in opt['datasets'].items():
|
|
phase = phase.split('_')[0]
|
|
dataset['phase'] = phase
|
|
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
|
dataset['scale'] = scale
|
|
is_lmdb = False
|
|
''' LMDB is not supported at this point with the mods I've been making.
|
|
if dataset.get('dataroot_GT', None) is not None:
|
|
dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
|
|
if dataset['dataroot_GT'].endswith('lmdb'):
|
|
is_lmdb = True
|
|
if dataset.get('dataroot_LQ', None) is not None:
|
|
dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
|
|
if dataset['dataroot_LQ'].endswith('lmdb'):
|
|
is_lmdb = True
|
|
'''
|
|
dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
|
|
if dataset['mode'].endswith('mc'): # for memcached
|
|
dataset['data_type'] = 'mc'
|
|
dataset['mode'] = dataset['mode'].replace('_mc', '')
|
|
|
|
# path
|
|
if 'path' in opt.keys():
|
|
for key, path in opt['path'].items():
|
|
if path and key in opt['path'] and key != 'strict_load':
|
|
opt['path'][key] = osp.expanduser(path)
|
|
else:
|
|
opt['path'] = {}
|
|
opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
|
|
if is_train:
|
|
experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
|
|
opt['path']['experiments_root'] = experiments_root
|
|
opt['path']['models'] = osp.join(experiments_root, 'models')
|
|
opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
|
|
opt['path']['log'] = experiments_root
|
|
opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
|
|
|
|
# change some options for debug mode
|
|
if 'debug' in opt['name']:
|
|
opt['train']['val_freq'] = 8
|
|
opt['logger']['print_freq'] = 1
|
|
opt['logger']['save_checkpoint_freq'] = 8
|
|
else: # test
|
|
results_root = osp.join(opt['path']['root'], 'results', opt['name'])
|
|
opt['path']['results_root'] = results_root
|
|
opt['path']['log'] = results_root
|
|
|
|
# network
|
|
if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
|
|
if 'network_G' in opt.keys():
|
|
opt['network_G']['scale'] = scale
|
|
|
|
return opt
|
|
|
|
|
|
def dict2str(opt, indent_l=1):
|
|
'''dict to string for logger'''
|
|
msg = ''
|
|
for k, v in opt.items():
|
|
if isinstance(v, dict):
|
|
msg += ' ' * (indent_l * 2) + k + ':[\n'
|
|
msg += dict2str(v, indent_l + 1)
|
|
msg += ' ' * (indent_l * 2) + ']\n'
|
|
else:
|
|
msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
|
|
return msg
|
|
|
|
|
|
class NoneDict(dict):
|
|
def __missing__(self, key):
|
|
return None
|
|
|
|
|
|
# convert to NoneDict, which return None for missing key.
|
|
def dict_to_nonedict(opt):
|
|
if isinstance(opt, dict):
|
|
new_opt = dict()
|
|
for key, sub_opt in opt.items():
|
|
new_opt[key] = dict_to_nonedict(sub_opt)
|
|
return NoneDict(**new_opt)
|
|
elif isinstance(opt, list):
|
|
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
|
|
else:
|
|
return opt
|
|
|
|
|
|
def check_resume(opt, resume_iter):
|
|
'''Check resume states and pretrain_model paths'''
|
|
logger = logging.getLogger('base')
|
|
if opt['path']['resume_state']:
|
|
if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
|
|
'pretrain_model_D', None) is not None:
|
|
logger.warning('pretrain_model path will be ignored when resuming training.')
|
|
|
|
if opt['model'] == 'extensibletrainer':
|
|
for k in opt['networks'].keys():
|
|
pt_key = 'pretrain_model_%s' % (k,)
|
|
opt['path'][pt_key] = osp.join(opt['path']['models'],
|
|
'{}_{}.pth'.format(resume_iter, k))
|
|
logger.info('Set model [%s] to %s' % (k, opt['path'][pt_key]))
|
|
else:
|
|
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
|
|
'{}_G.pth'.format(resume_iter))
|
|
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
|
|
if 'gan' in opt['model'] or 'spsr' in opt['model']:
|
|
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
|
|
'{}_D.pth'.format(resume_iter))
|
|
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
|
|
if 'spsr' in opt['model']:
|
|
opt['path']['pretrain_model_D_grad'] = osp.join(opt['path']['models'],
|
|
'{}_D_grad.pth'.format(resume_iter))
|
|
logger.info('Set [pretrain_model_D_grad] to ' + opt['path']['pretrain_model_D_grad'])
|