Add a feature-based validation test

This commit is contained in:
James Betker 2020-07-03 15:18:57 -06:00
parent 703dec4472
commit da4335c25e
2 changed files with 21 additions and 7 deletions

View File

@ -449,6 +449,14 @@ class SRGANModel(BaseModel):
# Just a note: this intentionally includes the swap model in the list of possibilities. # Just a note: this intentionally includes the swap model in the list of possibilities.
return previous_models[random.randint(0, len(previous_models)-1)] return previous_models[random.randint(0, len(previous_models)-1)]
def compute_fea_loss(self, real, fake):
with torch.no_grad():
real = real.unsqueeze(dim=0)
fake = fake.unsqueeze(dim=0)
real_fea = self.netF(real).detach()
fake_fea = self.netF(fake)
return self.cri_fea(fake_fea, real_fea).item()
# Called before verification/checkpoint to ensure we're using the real models and not a swapout variant. # Called before verification/checkpoint to ensure we're using the real models and not a swapout variant.
def force_restore_swapout(self): def force_restore_swapout(self):
if self.swapout_D_duration > 0: if self.swapout_D_duration > 0:

View File

@ -162,7 +162,7 @@ def main():
current_step = resume_state['iter'] current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers model.resume_training(resume_state) # handle optimizers and schedulers
else: else:
current_step = 0 current_step = -1
start_epoch = 0 start_epoch = 0
#### training #### training
@ -221,6 +221,7 @@ def main():
# does not support multi-GPU validation # does not support multi-GPU validation
pbar = util.ProgressBar(len(val_loader) * val_batch_sz) pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
avg_psnr = 0. avg_psnr = 0.
avg_fea_loss = 0.
idx = 0 idx = 0
colab_imgs_to_copy = [] colab_imgs_to_copy = []
for val_data in val_loader: for val_data in val_loader:
@ -245,10 +246,13 @@ def main():
if colab_mode: if colab_mode:
colab_imgs_to_copy.append(save_img_path) colab_imgs_to_copy.append(save_img_path)
# calculate PSNR # calculate PSNR (Naw - don't do that. PSNR sucks)
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
avg_psnr += util.calculate_psnr(sr_img, gt_img) #avg_psnr += util.calculate_psnr(sr_img, gt_img)
pbar.update('Test {}'.format(img_name)) #pbar.update('Test {}'.format(img_name))
# calculate fea loss
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
if colab_mode: if colab_mode:
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
@ -256,12 +260,14 @@ def main():
os.path.join(opt['remote_path'], 'val_images', img_base_name)) os.path.join(opt['remote_path'], 'val_images', img_base_name))
avg_psnr = avg_psnr / idx avg_psnr = avg_psnr / idx
avg_fea_loss = avg_fea_loss / idx
# log # log
logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
# tensorboard logger # tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name']: if opt['use_tb_logger'] and 'debug' not in opt['name']:
tb_logger.add_scalar('psnr', avg_psnr, current_step) tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
#### save models and training states #### save models and training states
if current_step % opt['logger']['save_checkpoint_freq'] == 0: if current_step % opt['logger']['save_checkpoint_freq'] == 0: