Enable megabatching

This commit is contained in:
James Betker 2020-05-02 17:46:59 -06:00
parent 61d3040cf5
commit 8341bf7646
2 changed files with 100 additions and 84 deletions

View File

@ -31,6 +31,9 @@ class SRGANModel(BaseModel):
# define losses, optimizer and scheduler # define losses, optimizer and scheduler
if self.is_train: if self.is_train:
self.mega_batch_factor = train_opt['mega_batch_factor']
if self.mega_batch_factor is None:
self.mega_batch_factor = 1
# G pixel loss # G pixel loss
if train_opt['pixel_weight'] > 0: if train_opt['pixel_weight'] > 0:
l_pix_type = train_opt['pixel_criterion'] l_pix_type = train_opt['pixel_criterion']
@ -138,12 +141,12 @@ class SRGANModel(BaseModel):
self.load() # load G and D if needed self.load() # load G and D if needed
def feed_data(self, data, need_GT=True): def feed_data(self, data, need_GT=True):
self.var_L = data['LQ'].to(self.device) # LQ self.var_L = torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0) # LQ
if need_GT: if need_GT:
self.var_H = data['GT'].to(self.device) # GT self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
input_ref = data['ref'] if 'ref' in data else data['GT'] input_ref = data['ref'] if 'ref' in data else data['GT']
self.var_ref = input_ref.to(self.device) self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
self.pix = data['PIX'].to(self.device) self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
def optimize_parameters(self, step): def optimize_parameters(self, step):
# G # G
@ -152,25 +155,23 @@ class SRGANModel(BaseModel):
if step > self.D_init_iters: if step > self.D_init_iters:
self.optimizer_G.zero_grad() self.optimizer_G.zero_grad()
self.fake_H = self.netG(self.var_L)
else:
self.fake_H = self.pix
if step % 50 == 0: self.fake_H = []
for i in range(self.var_L.shape[0]): for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i))) if step > self.D_init_iters:
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i))) fake_H = self.netG(var_L)
utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i))) else:
utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i))) fake_H = pix
self.fake_H.append(fake_H.detach())
l_g_total = 0 l_g_total = 0
if step % self.D_update_ratio == 0 and step > self.D_init_iters: if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: # pixel loss if self.cri_pix: # pixel loss
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix) l_g_pix = self.l_pix_w * self.cri_pix(fake_H, pix)
l_g_total += l_g_pix l_g_total += l_g_pix
if self.cri_fea: # feature loss if self.cri_fea: # feature loss
real_fea = self.netF(self.pix).detach() real_fea = self.netF(pix).detach()
fake_fea = self.netF(self.fake_H) fake_fea = self.netF(fake_H)
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
l_g_total += l_g_fea l_g_total += l_g_fea
@ -180,11 +181,11 @@ class SRGANModel(BaseModel):
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay) self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
if self.opt['train']['gan_type'] == 'gan': if self.opt['train']['gan_type'] == 'gan':
pred_g_fake = self.netD(self.fake_H) pred_g_fake = self.netD(fake_H)
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
elif self.opt['train']['gan_type'] == 'ragan': elif self.opt['train']['gan_type'] == 'ragan':
pred_d_real = self.netD(self.var_ref).detach() pred_d_real = self.netD(var_ref).detach()
pred_g_fake = self.netD(self.fake_H) pred_g_fake = self.netD(fake_H)
l_g_gan = self.l_gan_w * ( l_g_gan = self.l_gan_w * (
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
@ -199,37 +200,50 @@ class SRGANModel(BaseModel):
p.requires_grad = True p.requires_grad = True
self.optimizer_D.zero_grad() self.optimizer_D.zero_grad()
for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, self.var_ref, self.pix, self.fake_H):
if self.opt['train']['gan_type'] == 'gan': if self.opt['train']['gan_type'] == 'gan':
# need to forward and backward separately, since batch norm statistics differ # need to forward and backward separately, since batch norm statistics differ
# real # real
pred_d_real = self.netD(self.var_ref) pred_d_real = self.netD(var_ref)
l_d_real = self.cri_gan(pred_d_real, True) l_d_real = self.cri_gan(pred_d_real, True)
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward() l_d_real_scaled.backward()
# fake # fake
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G
l_d_fake = self.cri_gan(pred_d_fake, False) l_d_fake = self.cri_gan(pred_d_fake, False)
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
elif self.opt['train']['gan_type'] == 'ragan': elif self.opt['train']['gan_type'] == 'ragan':
# pred_d_real = self.netD(self.var_ref) # pred_d_real = self.netD(var_ref)
# pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G # pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G
# l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
# l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
# l_d_total = (l_d_real + l_d_fake) / 2 # l_d_total = (l_d_real + l_d_fake) / 2
# l_d_total.backward() # l_d_total.backward()
pred_d_fake = self.netD(self.fake_H.detach()).detach() pred_d_fake = self.netD(fake_H.detach()).detach()
pred_d_real = self.netD(self.var_ref) pred_d_real = self.netD(var_ref)
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5 l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled: with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
l_d_real_scaled.backward() l_d_real_scaled.backward()
pred_d_fake = self.netD(self.fake_H.detach()) pred_d_fake = self.netD(fake_H.detach())
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5 l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
self.optimizer_D.step() self.optimizer_D.step()
# set log # Log sample images from first microbatch.
if step % 50 == 0:
os.makedirs("temp/hr", exist_ok=True)
os.makedirs("temp/lr", exist_ok=True)
os.makedirs("temp/gen", exist_ok=True)
os.makedirs("temp/pix", exist_ok=True)
for i in range(self.var_L[0].shape[0]):
utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
utils.save_image(self.fake_H[0][i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
# set log TODO(handle mega-batches?)
if step % self.D_update_ratio == 0 and step > self.D_init_iters: if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.cri_pix: if self.cri_pix:
self.log_dict['l_g_pix'] = l_g_pix.item() self.log_dict['l_g_pix'] = l_g_pix.item()
@ -245,7 +259,7 @@ class SRGANModel(BaseModel):
def test(self): def test(self):
self.netG.eval() self.netG.eval()
with torch.no_grad(): with torch.no_grad():
self.fake_H = self.netG(self.var_L) self.fake_H = [self.netG(self.var_L[0])]
self.netG.train() self.netG.train()
def get_current_log(self): def get_current_log(self):
@ -253,10 +267,10 @@ class SRGANModel(BaseModel):
def get_current_visuals(self, need_GT=True): def get_current_visuals(self, need_GT=True):
out_dict = OrderedDict() out_dict = OrderedDict()
out_dict['LQ'] = self.var_L.detach()[0].float().cpu() out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() out_dict['rlt'] = self.fake_H[0].detach()[0].float().cpu()
if need_GT: if need_GT:
out_dict['GT'] = self.var_H.detach()[0].float().cpu() out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
return out_dict return out_dict
def print_network(self): def print_network(self):

View File

@ -5,7 +5,7 @@ model: srgan
distortion: sr distortion: sr
scale: 4 scale: 4
gpu_ids: [0] gpu_ids: [0]
amp_opt_level: O1 amp_opt_level: O0
#### datasets #### datasets
datasets: datasets:
@ -14,10 +14,10 @@ datasets:
mode: LQGT mode: LQGT
dataroot_GT: K:\4k6k\4k_closeup\hr dataroot_GT: K:\4k6k\4k_closeup\hr
dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
doCrop: false
use_shuffle: true use_shuffle: true
n_workers: 12 # per GPU n_workers: 12 # per GPU
batch_size: 12 batch_size: 64
target_size: 256 target_size: 256
color: RGB color: RGB
val: val:
@ -40,17 +40,18 @@ network_D:
#### path #### path
path: path:
pretrain_model_G: ~ pretrain_model_G: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_G.pth
pretrain_model_D: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_D.pth
strict_load: true strict_load: true
resume_state: ~ resume_state: ~
#### training settings: learning rate scheme, loss #### training settings: learning rate scheme, loss
train: train:
lr_G: !!float 1e-4 lr_G: !!float 5e-5
weight_decay_G: 0 weight_decay_G: 0
beta1_G: 0.9 beta1_G: 0.9
beta2_G: 0.99 beta2_G: 0.99
lr_D: !!float 2e-4 lr_D: !!float 8e-5
weight_decay_D: 0 weight_decay_D: 0
beta1_D: 0.9 beta1_D: 0.9
beta2_D: 0.99 beta2_D: 0.99
@ -58,21 +59,22 @@ train:
niter: 400000 niter: 400000
warmup_iter: -1 # no warm up warmup_iter: -1 # no warm up
lr_steps: [20000, 40000, 60000, 80000] lr_steps: [5000, 20000, 40000, 60000]
lr_gamma: 0.5 lr_gamma: 0.5
mega_batch_factor: 8
pixel_criterion: l1 pixel_criterion: l1
pixel_weight: !!float 1e-2 pixel_weight: !!float 1e-2
feature_criterion: l1 feature_criterion: l1
feature_weight: 1 feature_weight: 0
feature_weight_decay: .98 feature_weight_decay: .9
feature_weight_decay_steps: 500 feature_weight_decay_steps: 500
feature_weight_minimum: .1 feature_weight_minimum: .1
gan_type: gan # gan | ragan gan_type: gan # gan | ragan
gan_weight: !!float 5e-3 gan_weight: 1
D_update_ratio: 1 D_update_ratio: 1
D_init_iters: 0 D_init_iters: -1
manual_seed: 10 manual_seed: 10
val_freq: !!float 5e2 val_freq: !!float 5e2