forked from mrq/DL-Art-School
8f18b2709e
It is not clear to me what the purpose of this is, but it has recently started causing failures.
121 lines
4.9 KiB
Python
121 lines
4.9 KiB
Python
import os
|
|
import os.path as osp
|
|
import logging
|
|
import yaml
|
|
from utils.util import OrderedYaml
|
|
# Loader/Dumper pair built by utils.util.OrderedYaml; Loader is the YAML
# loader used by parse(). NOTE(review): presumably configured to keep mapping
# key order (per the helper's name) — confirm in utils.util.
Loader, Dumper = OrderedYaml()
|
|
|
|
def parse(opt_path, is_train=True):
    '''Load a YAML options file and fill in the settings derived from it.

    Args:
        opt_path: Path of the YAML options file to read.
        is_train: True for a training run, False for testing. Stored as
            ``opt['is_train']`` and selects which output directories are set.

    Returns:
        The loaded options mapping, updated in place with:
          * ``phase`` and ``data_type`` on every entry of ``opt['datasets']``;
          * ``~`` expanded in the string values of ``opt['path']`` (which is
            created empty when absent);
          * ``opt['path']['root']`` plus either the experiment directories
            (training) or the results directory (testing), named after
            ``opt['name']``.
    '''
    with open(opt_path, mode='r') as f:
        # NOTE(review): Loader comes from utils.util.OrderedYaml. If it is not
        # built on yaml.SafeLoader it can construct arbitrary Python objects,
        # so only load trusted option files.
        opt = yaml.load(f, Loader=Loader)

    opt['is_train'] = is_train
    _normalize_datasets(opt)
    _expand_user_paths(opt)

    paths = opt['path']
    # Three '..' applied to this module's file path: the directory two levels
    # above the folder that contains this module.
    paths['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
    if is_train:
        experiments_root = osp.join(paths['root'], 'experiments', opt['name'])
        paths['experiments_root'] = experiments_root
        paths['models'] = osp.join(experiments_root, 'models')
        paths['training_state'] = osp.join(experiments_root, 'training_state')
        paths['log'] = experiments_root
        paths['val_images'] = osp.join(experiments_root, 'val_images')

        # change some options for debug mode (experiment name contains 'debug')
        if 'debug' in opt['name']:
            opt['train']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(paths['root'], 'results', opt['name'])
        paths['results_root'] = results_root
        paths['log'] = results_root

    return opt


def _normalize_datasets(opt):
    '''Set ``phase`` and ``data_type`` on each entry of ``opt['datasets']`` (in place).'''
    # `or {}` also tolerates a present-but-empty `datasets:` key (None).
    for dataset_key, dataset in (opt.get('datasets') or {}).items():
        # The phase is the key up to the first underscore, e.g. 'train_2' -> 'train'.
        dataset['phase'] = dataset_key.split('_')[0]
        # LMDB is not supported with this fork's changes: the former detection
        # (dataroot_GT / dataroot_LQ ending in 'lmdb') is disabled, so every
        # non-memcached dataset is treated as 'img'.
        dataset['data_type'] = 'img'
        mode = dataset['mode']
        if mode.endswith('mc'):  # for memcached
            dataset['data_type'] = 'mc'
            # Strip only the trailing '_mc' marker; str.replace would also
            # delete '_mc' occurring elsewhere in the mode name.
            if mode.endswith('_mc'):
                dataset['mode'] = mode[:-len('_mc')]


def _expand_user_paths(opt):
    '''Expand ``~`` in the string values of ``opt['path']``; create it if missing.'''
    paths = opt.get('path')
    if paths is None:
        opt['path'] = {}
        return
    for key, path in paths.items():
        # 'strict_load' is a flag rather than a path. Non-string values
        # (bools, numbers) are left untouched because os.path.expanduser
        # would reject them with TypeError. Reassigning existing keys while
        # iterating is safe: the dict's size does not change.
        if key != 'strict_load' and path and isinstance(path, str):
            paths[key] = osp.expanduser(path)
|
|
|
|
|
|
def dict2str(opt, indent_l=1):
    '''Render a (possibly nested) options dict as an indented string for logging.

    Every level is indented by ``2 * indent_l`` spaces. Scalar entries render
    as ``key: value``; nested dicts render as ``key:[`` followed by their
    contents one level deeper and a closing ``]``. Each line ends with '\\n'.

    Args:
        opt: Mapping to render; nested dict values are rendered recursively.
        indent_l: Indentation level of the top-level entries.

    Returns:
        The rendered string ('' for an empty mapping).
    '''
    pad = ' ' * (indent_l * 2)
    parts = []
    for k, v in opt.items():
        if isinstance(v, dict):
            parts.append('%s%s:[\n' % (pad, k))
            parts.append(dict2str(v, indent_l + 1))
            parts.append('%s]\n' % pad)
        else:
            # %s applies str() to the key as well as the value, so non-string
            # keys (e.g. integer keys loaded from YAML) no longer raise TypeError.
            parts.append('%s%s: %s\n' % (pad, k, v))
    # Join once instead of repeatedly concatenating onto a growing string.
    return ''.join(parts)
|
|
|
|
|
|
class NoneDict(dict):
    '''A dict whose subscript lookups yield None for absent keys.

    Only ``d[key]`` is affected: ``get``, ``in`` and iteration behave like a
    plain dict, and looking up a missing key does not insert it.
    '''

    def __missing__(self, key):
        # dict.__getitem__ falls back to this hook when `key` is absent;
        # answering None instead of raising KeyError is the whole point.
        return None
|
|
|
|
|
|
def dict_to_nonedict(opt):
    '''Recursively convert dicts (including dicts nested in lists) to NoneDict.

    A NoneDict returns None for missing keys instead of raising KeyError.

    Args:
        opt: Any value; dicts and lists are converted recursively.

    Returns:
        A NoneDict for a dict, a new list for a list, otherwise ``opt`` itself.
    '''
    if isinstance(opt, dict):
        # Pass a mapping rather than **kwargs: keyword expansion requires
        # string keys and raised TypeError for e.g. integer keys from YAML.
        return NoneDict({key: dict_to_nonedict(sub_opt) for key, sub_opt in opt.items()})
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt
|
|
|
|
|
|
def check_resume(opt, resume_iter):
    '''Check resume states and pretrain_model paths.

    When ``opt['path']['resume_state']`` is set, overwrite the
    ``pretrain_model_*`` entries of ``opt['path']`` (in place) with
    ``<opt['path']['models']>/<resume_iter>_<suffix>.pth``:
      * model 'extensibletrainer': one ``pretrain_model_<name>`` per key of
        ``opt['networks']``;
      * any other model: ``pretrain_model_G``; also ``pretrain_model_D`` when
        the model name contains 'gan' or 'spsr'; also
        ``pretrain_model_D_grad`` when it contains 'spsr'.
    Does nothing when no resume state is configured.

    Args:
        opt: Options mapping, as produced by ``parse``.
        resume_iter: Iteration used as the checkpoint filename prefix.
    '''
    logger = logging.getLogger('base')
    paths = opt['path']
    # .get keeps a plain dict without 'resume_state' from raising KeyError;
    # for a NoneDict it is equivalent to indexing (both give None).
    if not paths.get('resume_state'):
        return

    if paths.get('pretrain_model_G') is not None or paths.get('pretrain_model_D') is not None:
        logger.warning('pretrain_model path will be ignored when resuming training.')

    model = opt['model']
    if model == 'extensibletrainer':
        for k in opt['networks']:
            pt_key = 'pretrain_model_%s' % (k,)
            paths[pt_key] = osp.join(paths['models'], '{}_{}.pth'.format(resume_iter, k))
            logger.info('Set model [%s] to %s', k, paths[pt_key])
    else:
        paths['pretrain_model_G'] = osp.join(paths['models'], '{}_G.pth'.format(resume_iter))
        logger.info('Set [pretrain_model_G] to %s', paths['pretrain_model_G'])
        if 'gan' in model or 'spsr' in model:
            paths['pretrain_model_D'] = osp.join(paths['models'], '{}_D.pth'.format(resume_iter))
            logger.info('Set [pretrain_model_D] to %s', paths['pretrain_model_D'])
        if 'spsr' in model:
            paths['pretrain_model_D_grad'] = osp.join(paths['models'],
                                                      '{}_D_grad.pth'.format(resume_iter))
            logger.info('Set [pretrain_model_D_grad] to %s', paths['pretrain_model_D_grad'])
|