diff --git a/codes/models/SR_model.py b/codes/models/SR_model.py
index 5226715b..bf46ea3f 100644
--- a/codes/models/SR_model.py
+++ b/codes/models/SR_model.py
@@ -8,6 +8,7 @@ import models.networks as networks
 import models.lr_scheduler as lr_scheduler
 from .base_model import BaseModel
 from models.loss import CharbonnierLoss
+from apex import amp
 
 logger = logging.getLogger('base')
 
@@ -23,7 +24,7 @@ class SRModel(BaseModel):
         train_opt = opt['train']
 
         # define network and load pretrained models
-        self.netG = networks.define_G(opt).to(self.device)
+        self.netG = amp.initialize(networks.define_G(opt).to(self.device), opt_level=self.amp_level)
         if opt['dist']:
             self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
         elif opt['gpu_ids'] is not None:
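
The hunk above registers only the generator with apex AMP (`amp.initialize` does accept a lone model and then returns just the patched model). For context, the usual apex pattern also passes the optimizer to `amp.initialize` and scales the loss during backward. The sketch below illustrates that pattern under assumptions: the layer sizes, learning rate, dummy tensors, and `opt_level='O1'` are placeholders, not values taken from this diff.

```python
# Minimal, self-contained sketch of the typical apex AMP training pattern.
# All concrete values (layers, lr, batch shapes, opt_level) are assumptions.
import torch
import torch.nn as nn
from apex import amp

netG = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1),
                     nn.ReLU(inplace=True),
                     nn.Conv2d(64, 3, 3, padding=1)).cuda()
optimizer_G = torch.optim.Adam(netG.parameters(), lr=1e-4)

# Register the model together with its optimizer so AMP can patch both.
netG, optimizer_G = amp.initialize(netG, optimizer_G, opt_level='O1')

lr_batch = torch.rand(4, 3, 32, 32, device='cuda')   # dummy low-res input
hr_batch = torch.rand(4, 3, 32, 32, device='cuda')   # dummy target

# Scale the loss during backward so FP16 gradients do not underflow.
optimizer_G.zero_grad()
loss = nn.functional.l1_loss(netG(lr_batch), hr_batch)
with amp.scale_loss(loss, optimizer_G) as scaled_loss:
    scaled_loss.backward()
optimizer_G.step()
```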