2019-08-23 13:42:47 +00:00
|
|
|
import logging
|
|
|
|
# Module-level logger registered under the name 'base'; create_model reports
# the instantiated model class through it.
logger = logging.getLogger(name='base')
|
|
|
|
|
|
|
|
|
|
|
|
def create_model(opt):
    """Instantiate the model class selected by ``opt['model']``.

    The model module is imported lazily, only for the selected branch, so
    unused model implementations are never loaded.

    Args:
        opt: options mapping. Its ``'model'`` entry picks the model class;
            the whole mapping is then passed to that class's constructor.

    Returns:
        The newly constructed model instance.

    Raises:
        KeyError: if ``opt`` has no ``'model'`` entry.
        NotImplementedError: if ``opt['model']`` names no known model.
    """
    model = opt['model']

    # image restoration
    if model == 'sr':  # PSNR-oriented super resolution
        from .SR_model import SRModel as M
    elif model in ('srgan', 'corruptgan', 'spsrgan'):
        # All three names share SRGANModel; note that 'spsrgan' does NOT
        # use SPSRModel (that is the separate 'spsr' entry below).
        from .SRGAN_model import SRGANModel as M
    elif model == 'feat':
        from .feature_model import FeatureModel as M
    elif model == 'spsr':
        from .SPSR_model import SPSRModel as M
    elif model == 'extensibletrainer':
        from .ExtensibleTrainer import ExtensibleTrainer as M
    else:
        # '{}' rather than '{:s}': a non-str value (e.g. None from an empty
        # config entry) must still yield this error, not a TypeError raised
        # while formatting the message.
        raise NotImplementedError('Model [{}] not recognized.'.format(model))

    m = M(opt)
    logger.info('Model [%s] is created.', m.__class__.__name__)
    return m
|