Use amp in SR_model for inference

This commit is contained in:
James Betker 2020-05-07 21:45:33 -06:00
parent dbca0d328c
commit 03351182be

View File

@ -8,6 +8,7 @@ import models.networks as networks
import models.lr_scheduler as lr_scheduler import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel from .base_model import BaseModel
from models.loss import CharbonnierLoss from models.loss import CharbonnierLoss
from apex import amp
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -23,7 +24,7 @@ class SRModel(BaseModel):
train_opt = opt['train'] train_opt = opt['train']
# define network and load pretrained models # define network and load pretrained models
self.netG = networks.define_G(opt).to(self.device) self.netG = amp.initialize(networks.define_G(opt).to(self.device), opt_level=self.amp_level)
if opt['dist']: if opt['dist']:
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
elif opt['gpu_ids'] is not None: elif opt['gpu_ids'] is not None: