Add a feature-based validation test
This commit is contained in:
parent
703dec4472
commit
da4335c25e
|
@ -449,6 +449,14 @@ class SRGANModel(BaseModel):
|
|||
# Just a note: this intentionally includes the swap model in the list of possibilities.
|
||||
return previous_models[random.randint(0, len(previous_models)-1)]
|
||||
|
||||
def compute_fea_loss(self, real, fake):
|
||||
with torch.no_grad():
|
||||
real = real.unsqueeze(dim=0)
|
||||
fake = fake.unsqueeze(dim=0)
|
||||
real_fea = self.netF(real).detach()
|
||||
fake_fea = self.netF(fake)
|
||||
return self.cri_fea(fake_fea, real_fea).item()
|
||||
|
||||
# Called before verification/checkpoint to ensure we're using the real models and not a swapout variant.
|
||||
def force_restore_swapout(self):
|
||||
if self.swapout_D_duration > 0:
|
||||
|
|
|
@ -162,7 +162,7 @@ def main():
|
|||
current_step = resume_state['iter']
|
||||
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||
else:
|
||||
current_step = 0
|
||||
current_step = -1
|
||||
start_epoch = 0
|
||||
|
||||
#### training
|
||||
|
@ -221,6 +221,7 @@ def main():
|
|||
# does not support multi-GPU validation
|
||||
pbar = util.ProgressBar(len(val_loader) * val_batch_sz)
|
||||
avg_psnr = 0.
|
||||
avg_fea_loss = 0.
|
||||
idx = 0
|
||||
colab_imgs_to_copy = []
|
||||
for val_data in val_loader:
|
||||
|
@ -245,10 +246,13 @@ def main():
|
|||
if colab_mode:
|
||||
colab_imgs_to_copy.append(save_img_path)
|
||||
|
||||
# calculate PSNR
|
||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
pbar.update('Test {}'.format(img_name))
|
||||
# calculate PSNR (Naw - don't do that. PSNR sucks)
|
||||
#sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
#avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
#pbar.update('Test {}'.format(img_name))
|
||||
|
||||
# calculate fea loss
|
||||
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
if colab_mode:
|
||||
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||
|
@ -256,12 +260,14 @@ def main():
|
|||
os.path.join(opt['remote_path'], 'val_images', img_base_name))
|
||||
|
||||
avg_psnr = avg_psnr / idx
|
||||
avg_fea_loss = avg_fea_loss / idx
|
||||
|
||||
# log
|
||||
logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
|
||||
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
tb_logger.add_scalar('psnr', avg_psnr, current_step)
|
||||
tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
|
||||
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
|
||||
|
||||
#### save models and training states
|
||||
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user