Allow validating in batches, remove val size limit
parent 90125f5bed
commit 726d1913ac
@@ -448,13 +448,13 @@ class SRGANModel(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
-        out_dict['LQ'] = self.var_L[0].detach()[0].float().cpu()
+        out_dict['LQ'] = self.var_L[0].detach().float().cpu()
         gen_batch = self.fake_GenOut[0]
         if isinstance(gen_batch, tuple):
             gen_batch = gen_batch[0]
-        out_dict['rlt'] = gen_batch.detach()[0].float().cpu()
+        out_dict['rlt'] = gen_batch.detach().float().cpu()
         if need_GT:
-            out_dict['GT'] = self.var_H[0].detach()[0].float().cpu()
+            out_dict['GT'] = self.var_H[0].detach().float().cpu()
         return out_dict
 
     def print_network(self):
@@ -144,10 +144,10 @@ class SRModel(BaseModel):
 
     def get_current_visuals(self, need_GT=True):
         out_dict = OrderedDict()
-        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
-        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
+        out_dict['LQ'] = self.var_L.detach().float().cpu()
+        out_dict['rlt'] = self.fake_H.detach().float().cpu()
         if need_GT:
-            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
+            out_dict['GT'] = self.real_H.detach().float().cpu()
         return out_dict
 
     def print_network(self):
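In both get_current_visuals hunks the trailing [0] index is dropped, so the returned tensors keep their batch dimension and callers pick out sample b themselves. A minimal sketch of the contract change (the tensor and shapes here are illustrative stand-ins for the models' batched outputs, not code from the repo):

import torch
from collections import OrderedDict

fake_H = torch.rand(4, 3, 128, 128)             # stand-in batch, [B, C, H, W]

old_rlt = fake_H.detach()[0].float().cpu()      # before: first sample only, [C, H, W]

visuals = OrderedDict()
visuals['rlt'] = fake_H.detach().float().cpu()  # after: whole batch, [B, C, H, W]
assert torch.equal(visuals['rlt'][0], old_rlt)  # callers now index per sample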
codes/train.py (126 changed lines)
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imset_pre_rrdb.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cifar_rrdb.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
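With the new default, a typical single-GPU invocation from the codes/ directory would look like the following (a hypothetical example; only the -opt and --launcher flags and the YAML path come from the diff):

python train.py -opt ../options/train_cifar_rrdb.yml --launcher none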
@@ -204,38 +204,38 @@ def main():
         if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
             if opt['model'] in ['sr', 'srgan', 'corruptgan'] and rank <= 0:  # image restoration validation
                 model.force_restore_swapout()
+                val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
                 # does not support multi-GPU validation
-                pbar = util.ProgressBar(len(val_loader))
+                pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
                 avg_psnr = 0.
                 idx = 0
                 colab_imgs_to_copy = []
                 for val_data in val_loader:
                     idx += 1
-                    if idx >= 20:
-                        break
-                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
-                    img_dir = os.path.join(opt['path']['val_images'], img_name)
-                    util.mkdir(img_dir)
+                    for b in range(len(val_data['LQ_path'])):
+                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
+                        img_dir = os.path.join(opt['path']['val_images'], img_name)
+                        util.mkdir(img_dir)
 
                     model.feed_data(val_data)
                     model.test()
 
                     visuals = model.get_current_visuals()
 
-                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
-                    gt_img = util.tensor2img(visuals['GT'])  # uint8
+                    sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
+                    gt_img = util.tensor2img(visuals['GT'][b])  # uint8
 
                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
                     save_img_path = os.path.join(img_dir, img_base_name)
                     util.save_img(sr_img, save_img_path)
                     if colab_mode:
                         colab_imgs_to_copy.append(save_img_path)
 
                     # calculate PSNR
                     sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
                     avg_psnr += util.calculate_psnr(sr_img, gt_img)
                     pbar.update('Test {}'.format(img_name))
 
                 if colab_mode:
                     util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
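Taken together, this hunk is what the commit title describes: the optional batch_size from the val dataset options sizes the progress bar, the hard 20-iteration cap is removed, and each batch element b is unpacked from the now-batched visuals. A rough sketch of that flow under those assumptions (validate and its injected model/val_loader/util/opt handles are stand-ins for the objects train.py builds; the util calls are the ones appearing in the diff):

def validate(model, val_loader, util, opt):
    """Sketch of a batched validation pass: PSNR averaged over every image, no size cap."""
    avg_psnr, n_imgs = 0.0, 0
    for val_data in val_loader:
        model.feed_data(val_data)
        model.test()
        visuals = model.get_current_visuals()       # batched tensors after this commit
        for b in range(len(val_data['LQ_path'])):   # the last batch may be short
            sr_img = util.tensor2img(visuals['rlt'][b])
            gt_img = util.tensor2img(visuals['GT'][b])
            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
            avg_psnr += util.calculate_psnr(sr_img, gt_img)
            n_imgs += 1
    return avg_psnr / max(n_imgs, 1)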
@@ -249,88 +249,6 @@ def main():
                 # tensorboard logger
                 if opt['use_tb_logger'] and 'debug' not in opt['name']:
                     tb_logger.add_scalar('psnr', avg_psnr, current_step)
-            else:  # video restoration validation
-                if opt['dist']:
-                    # multi-GPU testing
-                    psnr_rlt = {}  # with border and center frames
-                    if rank == 0:
-                        pbar = util.ProgressBar(len(val_set))
-                    for idx in range(rank, len(val_set), world_size):
-                        val_data = val_set[idx]
-                        val_data['LQs'].unsqueeze_(0)
-                        val_data['GT'].unsqueeze_(0)
-                        folder = val_data['folder']
-                        idx_d, max_idx = val_data['idx'].split('/')
-                        idx_d, max_idx = int(idx_d), int(max_idx)
-                        if psnr_rlt.get(folder, None) is None:
-                            psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32,
-                                                           device='cuda')
-                        # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
-                        model.feed_data(val_data)
-                        model.test()
-                        visuals = model.get_current_visuals()
-                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
-                        gt_img = util.tensor2img(visuals['GT'])  # uint8
-                        # calculate PSNR
-                        psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img)
-
-                        if rank == 0:
-                            for _ in range(world_size):
-                                pbar.update('Test {} - {}/{}'.format(folder, idx_d, max_idx))
-                    # # collect data
-                    for _, v in psnr_rlt.items():
-                        dist.reduce(v, 0)
-                    dist.barrier()
-
-                    if rank == 0:
-                        psnr_rlt_avg = {}
-                        psnr_total_avg = 0.
-                        for k, v in psnr_rlt.items():
-                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
-                            psnr_total_avg += psnr_rlt_avg[k]
-                        psnr_total_avg /= len(psnr_rlt)
-                        log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
-                        for k, v in psnr_rlt_avg.items():
-                            log_s += ' {}: {:.4e}'.format(k, v)
-                        logger.info(log_s)
-                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                            tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
-                            for k, v in psnr_rlt_avg.items():
-                                tb_logger.add_scalar(k, v, current_step)
-                else:
-                    pbar = util.ProgressBar(len(val_loader))
-                    psnr_rlt = {}  # with border and center frames
-                    psnr_rlt_avg = {}
-                    psnr_total_avg = 0.
-                    for val_data in val_loader:
-                        folder = val_data['folder'][0]
-                        idx_d = val_data['idx'].item()
-                        # border = val_data['border'].item()
-                        if psnr_rlt.get(folder, None) is None:
-                            psnr_rlt[folder] = []
-
-                        model.feed_data(val_data)
-                        model.test()
-                        visuals = model.get_current_visuals()
-                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
-                        gt_img = util.tensor2img(visuals['GT'])  # uint8
-
-                        # calculate PSNR
-                        psnr = util.calculate_psnr(rlt_img, gt_img)
-                        psnr_rlt[folder].append(psnr)
-                        pbar.update('Test {} - {}'.format(folder, idx_d))
-                    for k, v in psnr_rlt.items():
-                        psnr_rlt_avg[k] = sum(v) / len(v)
-                        psnr_total_avg += psnr_rlt_avg[k]
-                    psnr_total_avg /= len(psnr_rlt)
-                    log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
-                    for k, v in psnr_rlt_avg.items():
-                        log_s += ' {}: {:.4e}'.format(k, v)
-                    logger.info(log_s)
-                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                        tb_logger.add_scalar('psnr_avg', psnr_total_avg, current_step)
-                        for k, v in psnr_rlt_avg.items():
-                            tb_logger.add_scalar(k, v, current_step)
 
         #### save models and training states
         if current_step % opt['logger']['save_checkpoint_freq'] == 0: