forked from mrq/DL-Art-School
Enable AMP optimizations & write sample train images to folder.
This commit is contained in:
parent
9fc556be35
commit
4f6d3f0dfb
|
@ -7,6 +7,10 @@ import models.networks as networks
|
|||
import models.lr_scheduler as lr_scheduler
|
||||
from .base_model import BaseModel
|
||||
from models.loss import GANLoss
|
||||
from apex import amp
|
||||
|
||||
import torchvision.utils as utils
|
||||
import os
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -101,6 +105,10 @@ class SRGANModel(BaseModel):
|
|||
betas=(train_opt['beta1_D'], train_opt['beta2_D']))
|
||||
self.optimizers.append(self.optimizer_D)
|
||||
|
||||
# AMP
|
||||
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = \
|
||||
amp.initialize([self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=self.amp_level, num_losses=3)
|
||||
|
||||
# schedulers
|
||||
if train_opt['lr_scheme'] == 'MultiStepLR':
|
||||
for optimizer in self.optimizers:
|
||||
|
@ -132,6 +140,13 @@ class SRGANModel(BaseModel):
|
|||
self.var_ref = input_ref.to(self.device)
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
|
||||
if step % 50 == 0:
|
||||
for i in range(self.var_L.shape[0]):
|
||||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
|
||||
# G
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = False
|
||||
|
@ -161,7 +176,8 @@ class SRGANModel(BaseModel):
|
|||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
l_g_total += l_g_gan
|
||||
|
||||
l_g_total.backward()
|
||||
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
|
||||
l_g_total_scaled.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
# D
|
||||
|
@ -178,7 +194,8 @@ class SRGANModel(BaseModel):
|
|||
# fake
|
||||
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||
l_d_fake.backward()
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
# pred_d_real = self.netD(self.var_ref)
|
||||
# pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||
|
@ -189,10 +206,12 @@ class SRGANModel(BaseModel):
|
|||
pred_d_fake = self.netD(self.fake_H.detach()).detach()
|
||||
pred_d_real = self.netD(self.var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
|
||||
l_d_real.backward()
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
pred_d_fake = self.netD(self.fake_H.detach())
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
|
||||
l_d_fake.backward()
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
self.optimizer_D.step()
|
||||
|
||||
# set log
|
||||
|
|
|
@ -9,6 +9,7 @@ class BaseModel():
|
|||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
|
||||
self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
|
||||
self.is_train = opt['is_train']
|
||||
self.schedulers = []
|
||||
self.optimizers = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user