Enable megabatching
This commit is contained in:
parent
61d3040cf5
commit
8341bf7646
|
@ -31,6 +31,9 @@ class SRGANModel(BaseModel):
|
|||
|
||||
# define losses, optimizer and scheduler
|
||||
if self.is_train:
|
||||
self.mega_batch_factor = train_opt['mega_batch_factor']
|
||||
if self.mega_batch_factor is None:
|
||||
self.mega_batch_factor = 1
|
||||
# G pixel loss
|
||||
if train_opt['pixel_weight'] > 0:
|
||||
l_pix_type = train_opt['pixel_criterion']
|
||||
|
@ -138,12 +141,12 @@ class SRGANModel(BaseModel):
|
|||
self.load() # load G and D if needed
|
||||
|
||||
def feed_data(self, data, need_GT=True):
|
||||
self.var_L = data['LQ'].to(self.device) # LQ
|
||||
self.var_L = torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0) # LQ
|
||||
if need_GT:
|
||||
self.var_H = data['GT'].to(self.device) # GT
|
||||
self.var_H = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
|
||||
input_ref = data['ref'] if 'ref' in data else data['GT']
|
||||
self.var_ref = input_ref.to(self.device)
|
||||
self.pix = data['PIX'].to(self.device)
|
||||
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
# G
|
||||
|
@ -152,84 +155,95 @@ class SRGANModel(BaseModel):
|
|||
|
||||
if step > self.D_init_iters:
|
||||
self.optimizer_G.zero_grad()
|
||||
self.fake_H = self.netG(self.var_L)
|
||||
else:
|
||||
self.fake_H = self.pix
|
||||
|
||||
if step % 50 == 0:
|
||||
for i in range(self.var_L.shape[0]):
|
||||
utils.save_image(self.var_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\pix", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[i].cpu().detach(), os.path.join("E:\\4k6k\\temp\\gen", "%05i_%02i.png" % (step, i)))
|
||||
self.fake_H = []
|
||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||
if step > self.D_init_iters:
|
||||
fake_H = self.netG(var_L)
|
||||
else:
|
||||
fake_H = pix
|
||||
self.fake_H.append(fake_H.detach())
|
||||
|
||||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix: # pixel loss
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.pix)
|
||||
l_g_total += l_g_pix
|
||||
if self.cri_fea: # feature loss
|
||||
real_fea = self.netF(self.pix).detach()
|
||||
fake_fea = self.netF(self.fake_H)
|
||||
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fea
|
||||
l_g_total = 0
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix: # pixel loss
|
||||
l_g_pix = self.l_pix_w * self.cri_pix(fake_H, pix)
|
||||
l_g_total += l_g_pix
|
||||
if self.cri_fea: # feature loss
|
||||
real_fea = self.netF(pix).detach()
|
||||
fake_fea = self.netF(fake_H)
|
||||
l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
|
||||
l_g_total += l_g_fea
|
||||
|
||||
# Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
|
||||
# in the resultant image.
|
||||
if step % self.l_fea_w_decay_steps == 0:
|
||||
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
||||
# Decay the influence of the feature loss. As the model trains, the GAN will play a stronger role
|
||||
# in the resultant image.
|
||||
if step % self.l_fea_w_decay_steps == 0:
|
||||
self.l_fea_w = max(self.l_fea_w_minimum, self.l_fea_w * self.l_fea_w_decay)
|
||||
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
pred_g_fake = self.netD(self.fake_H)
|
||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_d_real = self.netD(self.var_ref).detach()
|
||||
pred_g_fake = self.netD(self.fake_H)
|
||||
l_g_gan = self.l_gan_w * (
|
||||
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
l_g_total += l_g_gan
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
pred_g_fake = self.netD(fake_H)
|
||||
l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
pred_d_real = self.netD(var_ref).detach()
|
||||
pred_g_fake = self.netD(fake_H)
|
||||
l_g_gan = self.l_gan_w * (
|
||||
self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
l_g_total += l_g_gan
|
||||
|
||||
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
|
||||
l_g_total_scaled.backward()
|
||||
self.optimizer_G.step()
|
||||
with amp.scale_loss(l_g_total, self.optimizer_G, loss_id=0) as l_g_total_scaled:
|
||||
l_g_total_scaled.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
# D
|
||||
for p in self.netD.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
self.optimizer_D.zero_grad()
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
# need to forward and backward separately, since batch norm statistics differ
|
||||
# real
|
||||
pred_d_real = self.netD(self.var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real, True)
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
# fake
|
||||
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
# pred_d_real = self.netD(self.var_ref)
|
||||
# pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
|
||||
# l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
|
||||
# l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
||||
# l_d_total = (l_d_real + l_d_fake) / 2
|
||||
# l_d_total.backward()
|
||||
pred_d_fake = self.netD(self.fake_H.detach()).detach()
|
||||
pred_d_real = self.netD(self.var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
pred_d_fake = self.netD(self.fake_H.detach())
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
for var_L, var_H, var_ref, pix, fake_H in zip(self.var_L, self.var_H, self.var_ref, self.pix, self.fake_H):
|
||||
if self.opt['train']['gan_type'] == 'gan':
|
||||
# need to forward and backward separately, since batch norm statistics differ
|
||||
# real
|
||||
pred_d_real = self.netD(var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real, True)
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
# fake
|
||||
pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G
|
||||
l_d_fake = self.cri_gan(pred_d_fake, False)
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
elif self.opt['train']['gan_type'] == 'ragan':
|
||||
# pred_d_real = self.netD(var_ref)
|
||||
# pred_d_fake = self.netD(fake_H.detach()) # detach to avoid BP to G
|
||||
# l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
|
||||
# l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
|
||||
# l_d_total = (l_d_real + l_d_fake) / 2
|
||||
# l_d_total.backward()
|
||||
pred_d_fake = self.netD(fake_H.detach()).detach()
|
||||
pred_d_real = self.netD(var_ref)
|
||||
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) * 0.5
|
||||
with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
|
||||
l_d_real_scaled.backward()
|
||||
pred_d_fake = self.netD(fake_H.detach())
|
||||
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
|
||||
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
|
||||
l_d_fake_scaled.backward()
|
||||
self.optimizer_D.step()
|
||||
|
||||
# set log
|
||||
# Log sample images from first microbatch.
|
||||
if step % 50 == 0:
|
||||
os.makedirs("temp/hr", exist_ok=True)
|
||||
os.makedirs("temp/lr", exist_ok=True)
|
||||
os.makedirs("temp/gen", exist_ok=True)
|
||||
os.makedirs("temp/pix", exist_ok=True)
|
||||
for i in range(self.var_L[0].shape[0]):
|
||||
utils.save_image(self.var_H[0][i].cpu().detach(), os.path.join("temp/hr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.var_L[0][i].cpu().detach(), os.path.join("temp/lr", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.pix[0][i].cpu().detach(), os.path.join("temp/pix", "%05i_%02i.png" % (step, i)))
|
||||
utils.save_image(self.fake_H[0][i].cpu().detach(), os.path.join("temp/gen", "%05i_%02i.png" % (step, i)))
|
||||
|
||||
# set log TODO(handle mega-batches?)
|
||||
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
|
||||
if self.cri_pix:
|
||||
self.log_dict['l_g_pix'] = l_g_pix.item()
|
||||
|
@ -245,7 +259,7 @@ class SRGANModel(BaseModel):
|
|||
def test(self):
|
||||
self.netG.eval()
|
||||
with torch.no_grad():
|
||||
self.fake_H = self.netG(self.var_L)
|
||||
self.fake_H = [self.netG(self.var_L[0])]
|
||||
self.netG.train()
|
||||
|
||||
def get_current_log(self):
|
||||
|
@ -253,10 +267,10 @@ class SRGANModel(BaseModel):
|
|||
|
||||
def get_current_visuals(self, need_GT=True):
|
||||
out_dict = OrderedDict()
|
||||
out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
|
||||
out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
|
||||
out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
|
||||
out_dict['rlt'] = self.fake_H[0].detach()[0].float().cpu()
|
||||
if need_GT:
|
||||
out_dict['GT'] = self.var_H.detach()[0].float().cpu()
|
||||
out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
|
||||
return out_dict
|
||||
|
||||
def print_network(self):
|
||||
|
|
|
@ -5,7 +5,7 @@ model: srgan
|
|||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0]
|
||||
amp_opt_level: O1
|
||||
amp_opt_level: O0
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
|
@ -14,10 +14,10 @@ datasets:
|
|||
mode: LQGT
|
||||
dataroot_GT: K:\4k6k\4k_closeup\hr
|
||||
dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
|
||||
|
||||
doCrop: false
|
||||
use_shuffle: true
|
||||
n_workers: 12 # per GPU
|
||||
batch_size: 12
|
||||
batch_size: 64
|
||||
target_size: 256
|
||||
color: RGB
|
||||
val:
|
||||
|
@ -40,17 +40,18 @@ network_D:
|
|||
|
||||
#### path
|
||||
path:
|
||||
pretrain_model_G: ~
|
||||
pretrain_model_G: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_G.pth
|
||||
pretrain_model_D: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_D.pth
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 1e-4
|
||||
lr_G: !!float 5e-5
|
||||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 2e-4
|
||||
lr_D: !!float 8e-5
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
|
@ -58,21 +59,22 @@ train:
|
|||
|
||||
niter: 400000
|
||||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [20000, 40000, 60000, 80000]
|
||||
lr_steps: [5000, 20000, 40000, 60000]
|
||||
lr_gamma: 0.5
|
||||
mega_batch_factor: 8
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
feature_criterion: l1
|
||||
feature_weight: 1
|
||||
feature_weight_decay: .98
|
||||
feature_weight: 0
|
||||
feature_weight_decay: .9
|
||||
feature_weight_decay_steps: 500
|
||||
feature_weight_minimum: .1
|
||||
gan_type: gan # gan | ragan
|
||||
gan_weight: !!float 5e-3
|
||||
gan_weight: 1
|
||||
|
||||
D_update_ratio: 1
|
||||
D_init_iters: 0
|
||||
D_init_iters: -1
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e2
|
||||
|
|
Loading…
Reference in New Issue
Block a user