Add a feature-based validation test
parent 703dec4472
commit da4335c25e
@@ -449,6 +449,14 @@ class SRGANModel(BaseModel):
         # Just a note: this intentionally includes the swap model in the list of possibilities.
         return previous_models[random.randint(0, len(previous_models)-1)]
 
+    def compute_fea_loss(self, real, fake):
+        with torch.no_grad():
+            real = real.unsqueeze(dim=0)
+            fake = fake.unsqueeze(dim=0)
+            real_fea = self.netF(real).detach()
+            fake_fea = self.netF(fake)
+            return self.cri_fea(fake_fea, real_fea).item()
+
     # Called before verification/checkpoint to ensure we're using the real models and not a swapout variant.
     def force_restore_swapout(self):
         if self.swapout_D_duration > 0:
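For context on the new method: the diff itself never defines `netF` or `cri_fea`, but in ESRGAN-derived codebases like this one, `netF` is typically a pretrained VGG19 truncated at a deep conv layer and `cri_fea` an L1 criterion over feature maps. The sketch below is an illustrative stand-in under those assumptions, not this repository's actual definitions.

# Illustrative only: VGGFeatures and feature_loss are assumed stand-ins for
# this repo's netF / cri_fea, which are not shown in the diff.
import torch
import torch.nn as nn
import torchvision.models as models

class VGGFeatures(nn.Module):
    # Truncated pretrained VGG19; layer 35 is roughly conv5_4 pre-activation,
    # a common choice for ESRGAN-style feature losses.
    def __init__(self, layer=35):
        super().__init__()
        vgg = models.vgg19(pretrained=True)
        self.features = nn.Sequential(*list(vgg.features.children())[:layer])
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.features(x)

def feature_loss(net_f, criterion, real, fake):
    # Mirrors compute_fea_loss: unsqueeze each image to a batch of one,
    # extract features for both, and return the scalar distance between them.
    with torch.no_grad():
        real_fea = net_f(real.unsqueeze(dim=0)).detach()
        fake_fea = net_f(fake.unsqueeze(dim=0))
        return criterion(fake_fea, real_fea).item()

# Example usage:
#   net_f, cri = VGGFeatures().eval(), nn.L1Loss()
#   score = feature_loss(net_f, cri, gt_img_chw, sr_img_chw)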
@@ -162,7 +162,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = 0
+        current_step = -1
         start_epoch = 0
 
     #### training
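A note on the `current_step = 0` → `-1` change above: in MMSR-style training scripts the step counter is incremented at the top of the training loop, so starting at -1 presumably makes the first iteration of a fresh run land on step 0. That is an inference from the surrounding code, not something the commit states.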
@@ -221,6 +221,7 @@ def main():
                 # does not support multi-GPU validation
                 pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
                 avg_psnr = 0.
+                avg_fea_loss = 0.
                 idx = 0
                 colab_imgs_to_copy = []
                 for val_data in val_loader:
@@ -245,10 +246,13 @@ def main():
                     if colab_mode:
                         colab_imgs_to_copy.append(save_img_path)
 
-                    # calculate PSNR
-                    sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
-                    pbar.update('Test {}'.format(img_name))
+                    # calculate PSNR (Naw - don't do that. PSNR sucks)
+                    #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                    #avg_psnr += util.calculate_psnr(sr_img, gt_img)
+                    #pbar.update('Test {}'.format(img_name))
+
+                    # calculate fea loss
+                    avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
 
                 if colab_mode:
                     util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
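On why the PSNR calculation gets commented out above: PSNR is a pixel-wise metric that tends to reward blurry, averaged outputs, while a distance measured in pretrained-network feature space tracks perceptual quality more closely. That is presumably the motivation behind the "PSNR sucks" comment and the switch to accumulating the feature loss per image instead.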
@@ -256,12 +260,14 @@ def main():
                                               os.path.join(opt['remote_path'], 'val_images', img_base_name))
 
                 avg_psnr = avg_psnr / idx
+                avg_fea_loss = avg_fea_loss / idx
 
                 # log
-                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
+                logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
                 # tensorboard logger
                 if opt['use_tb_logger'] and 'debug' not in opt['name']:
-                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
+                    tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
+                    tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
 
             #### save models and training states
             if current_step % opt['logger']['save_checkpoint_freq'] == 0:
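Putting the pieces together, here is a minimal sketch of the validation flow this commit creates. `feed_data`, `test`, and `get_current_visuals` are assumed MMSR-style model methods (they do not appear in the diff), and `tb_logger` is assumed to be a torch SummaryWriter; the accumulation and logging mirror the hunks above.

from torch.utils.tensorboard import SummaryWriter

def validate_fea(model, val_loader, tb_logger, current_step):
    # Accumulate the feature loss over every validation image, then log the
    # mean under 'val_fea', matching the hunks above. Here idx counts images
    # rather than batches, so the mean is per-image.
    avg_fea_loss, idx = 0.0, 0
    for val_data in val_loader:
        model.feed_data(val_data)          # assumed MMSR-style API
        model.test()
        visuals = model.get_current_visuals()
        for b in range(len(visuals['rlt'])):
            avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
            idx += 1
    avg_fea_loss /= max(idx, 1)            # guard against an empty loader
    tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
    return avg_fea_loss

# Example:
#   tb_logger = SummaryWriter('../tb_logger/experiment')
#   validate_fea(model, val_loader, tb_logger, current_step=1000)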