diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 051f5076..40241f8b 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -7,6 +7,10 @@ import models.networks as networks import models.lr_scheduler as lr_scheduler from .base_model import BaseModel from models.loss import GANLoss +from apex import amp + +import torchvision.utils as utils +import os logger = logging.getLogger('base') @@ -101,6 +105,10 @@ class SRGANModel(BaseModel): betas=(train_opt['beta1_D'], train_opt['beta2_D'])) self.optimizers.append(self.optimizer_D) + # AMP + [self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \ + amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3) + # schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: @@ -132,6 +140,13 @@ class SRGANModel(BaseModel): self.var_ref = input_ref.to(self.device) def optimize_parameters(self, step): + + if step % 50 == 0: + for i in range(self.var_L.shape[0]): + utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i))) + utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i))) + + # G for p in self.netD.parameters(): p.requires_grad = False @@ -161,7 +176,8 @@ class SRGANModel(BaseModel): self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 l_g_total += l_g_gan - l_g_total.backward() + with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled: + l_g_total_scaled.backward() self.optimizer_G.step() # D @@ -178,7 +194,8 @@ class SRGANModel(BaseModel): # fake pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G l_d_fake = self.cri_gan(pred_d_fake, False) - l_d_fake.backward() + with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: + l_d_fake_scaled.backward() elif self.opt['train']['gan_type'] == 'ragan': # pred_d_real = self.netD(self.var_ref) # pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G @@ -189,10 +206,12 @@ class SRGANModel(BaseModel): pred_d_fake = self.netD(self.fake_H.detach()).detach() pred_d_real = self.netD(self.var_ref) l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 - l_d_real.backward() + with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: + l_d_real_scaled.backward() pred_d_fake = self.netD(self.fake_H.detach()) l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 - l_d_fake.backward() + with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: + l_d_fake_scaled.backward() self.optimizer_D.step() # set log diff --git a/codes/models/base_model.py b/codes/models/base_model.py index 8a5d2225..7d9dfdb1 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -9,6 +9,7 @@ class BaseModel(): def __init__(self, opt): self.opt = opt self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') + self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level'] self.is_train = opt['is_train'] self.schedulers = [] self.optimizers = []