forked from mrq/DL-Art-School
Use amp in SR_model for inference
This commit is contained in:
parent
dbca0d328c
commit
03351182be
|
@ -8,6 +8,7 @@ import models.networks as networks
|
|||
import models.lr_scheduler as lr_scheduler
|
||||
from .base_model import BaseModel
|
||||
from models.loss import CharbonnierLoss
|
||||
from apex import amp
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -23,7 +24,7 @@ class SRModel(BaseModel):
|
|||
train_opt = opt['train']
|
||||
|
||||
# define network and load pretrained models
|
||||
self.netG = networks.define_G(opt).to(self.device)
|
||||
self.netG = amp.initialize(networks.define_G(opt).to(self.device), opt_level=self.amp_level)
|
||||
if opt['dist']:
|
||||
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
||||
elif opt['gpu_ids'] is not None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user