Enable megabatching

This commit is contained in:
James Betker 2020-05-02 17:46:59 -06:00
parent 61d3040cf5
commit 8341bf7646
2 changed files with 100 additions and 84 deletions

View File

@ -31,6 +31,9 @@ class SRGANModel(BaseModel):
# define losses, optimizer and scheduler
if self.is_train:
self.mega_batch_factor = train_opt['mega_batch_factor']
if self.mega_batch_factor is None:
self.mega_batch_factor = 1
# G pixel loss
if train_opt['pixel_weight'] > 0:
l_pix_type = train_opt['pixel_criterion']
@ -138,12 +141,12 @@ class SRGANModel(BaseModel):
self.load() # load G and D if needed
def feed_data(self, data, need_GT=True):
self.var_L = data['LQ'].to(self.device) # LQ
self.var_L = torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0) # LQ
if need_GT:
self.var_H = data['GT'].to(self.device) # GT
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data else data['GT']
self.var_ref = input_ref.to(self.device)
self.pix = data['PIX'].to(self.device)
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
def optimize_parameters(self, step):
# G
@ -152,25 +155,23 @@ class SRGANModel(BaseModel):
if step > self.D_init_iters:
self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
else:
self.fake_H = self.pix
if step % 50 == 0:
for i in range(self.var_L.shape[0]):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i)))
self.fake_H = []
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
if step > self.D_init_iters:
fake_H = self.netG(var_L)
else:
fake_H = pix
self.fake_H.append(fake_H.detach())
l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix)
l_g_pix = self.l_pix_w * self.cri_pix(fake_H, pix)
l_g_total += l_g_pix
if self.cri_fea: # feature loss
real_fea = self.netF(self.pix).detach()
fake_fea = self.netF(self.fake_H)
real_fea = self.netF(pix).detach()
fake_fea = self.netF(fake_H)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fea
@ -180,11 +181,11 @@ class SRGANModel(BaseModel):
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
if self.opt['train']['gan_type'] == 'gan':
pred_g_fake = self.netD(self.fake_H)
pred_g_fake = self.netD(fake_H)
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
elif self.opt['train']['gan_type'] == 'ragan':
pred_d_real = self.netD(self.var_ref).detach()
pred_g_fake = self.netD(self.fake_H)
pred_d_real = self.netD(var_ref).detach()
pred_g_fake = self.netD(fake_H)
l_g_gan = self.l_gan_w * (
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@ -199,37 +200,50 @@ class SRGANModel(BaseModel):
p.requires_grad = True
self.optimizer_D.zero_grad()
for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, self.var_ref, self.pix, self.fake_H):
if self.opt['train']['gan_type'] == 'gan':
# need to forward and backward separately, since batch norm statistics differ
# real
pred_d_real = self.netD(self.var_ref)
pred_d_real = self.netD(var_ref)
l_d_real = self.cri_gan(pred_d_real, True)
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
# fake
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G
l_d_fake = self.cri_gan(pred_d_fake, False)
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
elif self.opt['train']['gan_type'] == 'ragan':
# pred_d_real = self.netD(self.var_ref)
# pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
# pred_d_real = self.netD(var_ref)
# pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G
# l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
# l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
# l_d_total = (l_d_real + l_d_fake) / 2
# l_d_total.backward()
pred_d_fake = self.netD(self.fake_H.detach()).detach()
pred_d_real = self.netD(self.var_ref)
pred_d_fake = self.netD(fake_H.detach()).detach()
pred_d_real = self.netD(var_ref)
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward()
pred_d_fake = self.netD(self.fake_H.detach())
pred_d_fake = self.netD(fake_H.detach())
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
self.optimizer_D.step()
# set log
# Log sample images from first microbatch.
if step % 50 == 0:
os.makedirs("temp/hr", exist_ok=True)
os.makedirs("temp/lr", exist_ok=True)
os.makedirs("temp/gen", exist_ok=True)
os.makedirs("temp/pix", exist_ok=True)
for i in range(self.var_L[0].shape[0]):
utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[0][i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
# set log TODO(handle mega-batches?)
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix:
self.log_dict['l_g_pix'] = l_g_pix.item()
@ -245,7 +259,7 @@ class SRGANModel(BaseModel):
def test(self):
self.netG.eval()
with torch.no_grad():
self.fake_H = self.netG(self.var_L)
self.fake_H = [self.netG(self.var_L[0])]
self.netG.train()
def get_current_log(self):
@ -253,10 +267,10 @@ class SRGANModel(BaseModel):
def get_current_visuals(self, need_GT=True):
out_dict = OrderedDict()
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H[0].detach()[0].float().cpu()
if need_GT:
out_dict['GT'] = self.var_H.detach()[0].float().cpu()
out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
return out_dict
def print_network(self):

View File

@ -5,7 +5,7 @@ model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1
amp_opt_level: O0
#### datasets
datasets:
@ -14,10 +14,10 @@ datasets:
mode: LQGT
dataroot_GT: K:\4k6k\4k_closeup\hr
dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
doCrop: false
use_shuffle: true
n_workers: 12 # per GPU
batch_size: 12
batch_size: 64
target_size: 256
color: RGB
val:
@ -40,17 +40,18 @@ network_D:
#### path
path:
pretrain_model_G: ~
pretrain_model_G: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_G.pth
pretrain_model_D: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_D.pth
strict_load: true
resume_state: ~
#### training settings: learning rate scheme, loss
train:
lr_G: !!float 1e-4
lr_G: !!float 5e-5
weight_decay_G: 0
beta1_G: 0.9
beta2_G: 0.99
lr_D: !!float 2e-4
lr_D: !!float 8e-5
weight_decay_D: 0
beta1_D: 0.9
beta2_D: 0.99
@ -58,21 +59,22 @@ train:
niter: 400000
warmup_iter: -1 # no warm up
lr_steps: [20000, 40000, 60000, 80000]
lr_steps: [5000, 20000, 40000, 60000]
lr_gamma: 0.5
mega_batch_factor: 8
pixel_criterion: l1
pixel_weight: !!float 1e-2
feature_criterion: l1
feature_weight: 1
feature_weight_decay: .98
feature_weight: 0
feature_weight_decay: .9
feature_weight_decay_steps: 500
feature_weight_minimum: .1
gan_type: gan # gan | ragan
gan_weight: !!float 5e-3
gan_weight: 1
D_update_ratio: 1
D_init_iters: 0
D_init_iters: -1
manual_seed: 10
val_freq: !!float 5e2