Enable AMP optimizations & write sample train images to folder.

This commit is contained in:
James Betker 2020-04-21 16:28:06 -06:00
parent 9fc556be35
commit 4f6d3f0dfb
2 changed files with 24 additions and 4 deletions

View File

@ -7,6 +7,10 @@ import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import GANLoss
from apex import amp
import torchvision.utils as utils
import os
logger = logging.getLogger('base')
@ -101,6 +105,10 @@ class SRGANModel(BaseModel):
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
self.optimizers.append(self.optimizer_D)
# AMP
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
# schedulers
if train_opt['lr_scheme'] == 'MultiStepLR':
for optimizer in self.optimizers:
@ -132,6 +140,13 @@ class SRGANModel(BaseModel):
self.var_ref = input_ref.to(self.device)
def optimize_parameters(self, step):
if step % 50 == 0:
for i in range(self.var_L.shape[0]):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
# G
for p in self.netD.parameters():
p.requires_grad = False
@ -161,7 +176,8 @@ class SRGANModel(BaseModel):
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
l_g_total += l_g_gan
l_g_total.backward()
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
l_g_total_scaled.backward()
self.optimizer_G.step()
# D
@ -178,7 +194,8 @@ class SRGANModel(BaseModel):
# fake
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
l_d_fake = self.cri_gan(pred_d_fake, False)
l_d_fake.backward()
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
elif self.opt['train']['gan_type'] == 'ragan':
# pred_d_real = self.netD(self.var_ref)
# pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
@ -189,10 +206,12 @@ class SRGANModel(BaseModel):
pred_d_fake = self.netD(self.fake_H.detach()).detach()
pred_d_real = self.netD(self.var_ref)
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
l_d_real.backward()
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
pred_d_fake = self.netD(self.fake_H.detach())
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
l_d_fake.backward()
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
self.optimizer_D.step()
# set log

View File

@ -9,6 +9,7 @@ class BaseModel():
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []