forked from mrq/DL-Art-School
Support validation over a custom injector
Also re-enable PSNR
parent ffad0e0422
commit 981d64413b
@@ -10,6 +10,7 @@ from torch.nn.parallel.distributed import DistributedDataParallel
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
 from models.base_model import BaseModel
+from models.steps.injectors import create_injector
 from models.steps.steps import ConfigurableStep
 from models.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils
@@ -155,7 +156,7 @@ class ExtensibleTrainer(BaseModel):
             o.zero_grad()
         torch.cuda.empty_cache()
 
-        self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
+        self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)]
         if need_GT:
             self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
             input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
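The changed line above splits the batch on the CPU first and then moves each chunk to the device separately, so every mega-batch chunk is an independent device tensor rather than a view into one large allocation. A minimal sketch of the pattern; the shapes and the mega_batch_factor value are illustrative assumptions:

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mega_batch_factor = 2  # illustrative value

    # A full batch of LQ images on the CPU (shape is illustrative).
    batch = torch.randn(8, 3, 32, 32)

    # Old behavior: move the whole batch, then split it into views.
    chunks_old = torch.chunk(batch.to(device), chunks=mega_batch_factor, dim=0)

    # New behavior: split on the CPU, then transfer each chunk on its own.
    chunks_new = [t.to(device) for t in torch.chunk(batch, chunks=mega_batch_factor, dim=0)]

    assert all(c.shape == (4, 3, 32, 32) for c in chunks_new)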
@@ -260,19 +261,29 @@ class ExtensibleTrainer(BaseModel):
             net.eval()
 
         with torch.no_grad():
-            # Iterate through the steps, performing them one at a time.
-            state = self.dstate
-            for step_num, s in enumerate(self.steps):
-                ns = s.do_forward_backward(state, 0, step_num, train=False)
-                for k, v in ns.items():
-                    state[k] = [v]
+            # This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
+            # Or, we run the entire chain of steps in "train" mode and use eval.output_state.
+            if 'injector' in self.opt['eval'].keys():
+                # Need to move from mega_batch mode to batch mode (remove chunks)
+                state = {}
+                for k, v in self.dstate.items():
+                    state[k] = v[0]
+                inj = create_injector(self.opt['eval']['injector'], self.env)
+                state.update(inj(state))
+            else:
+                # Iterate through the steps, performing them one at a time.
+                state = self.dstate
+                for step_num, s in enumerate(self.steps):
+                    ns = s.do_forward_backward(state, 0, step_num, train=False)
+                    for k, v in ns.items():
+                        state[k] = [v]
 
             self.eval_state = {}
             for k, v in state.items():
-                self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
-
-            # For backwards compatibility..
-            self.fake_H = self.eval_state[self.opt['eval']['output_state']][0].float().cpu()
+                if isinstance(v, list):
+                    self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
+                else:
+                    self.eval_state[k] = [v.detach().cpu() if isinstance(v, torch.Tensor) else v]
 
         for net in self.netsG.values():
            net.train()
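The new eval path is driven entirely by the options file. Below is a sketch of the opt dict this code consults, written as the parsed form of the option YAML; only the 'injector' and 'output_state' keys are read by the code above, and the injector body shown is an illustrative assumption rather than a documented schema:

    opt = {
        'eval': {
            # When present, validation collapses self.dstate from mega-batch
            # chunks back to a plain batch and runs this single injector.
            'injector': {
                'type': 'generator',    # hypothetical injector spec
                'generator': 'generator',
                'in': 'lq',
                'out': 'gen',
            },
            # Names the eval_state entry treated as the network output elsewhere.
            'output_state': 'gen',
        },
    }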
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_chained.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_spsr7.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -185,10 +185,6 @@ def main():
             print("Data fetch: %f" % (time() - _t))
             _t = time()
 
-            #tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
-            #                                                      train_data['lq_fullsize_ref'].float().to('cuda'),
-            #                                                      train_data['lq_center'].to('cuda')])
-
             current_step += 1
             if current_step > total_iters:
                 break
@@ -241,9 +237,6 @@ def main():
         #### validation
         if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
             if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0:  # image restoration validation
-                model.force_restore_swapout()
-                val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
-                # does not support multi-GPU validation
                 avg_psnr = 0.
                 avg_fea_loss = 0.
                 idx = 0
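For context, the validation block above is gated by two option keys: a 'val' dataset must be configured and the current step must land on the validation interval. A hedged sketch of the relevant slice of opt, with illustrative values:

    opt = {
        'model': 'extensibletrainer',
        'train': {
            'val_freq': 500,          # validate every 500 steps (illustrative)
        },
        'datasets': {
            'val': {                  # presence of this key enables validation
                'name': 'val_set',    # illustrative dataset options
            },
        },
    }

    current_step = 500
    run_validation = (opt['datasets'].get('val', None)
                      and current_step % opt['train']['val_freq'] == 0)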
@@ -263,23 +256,22 @@ def main():
                     if visuals is None:
                         continue
 
+                    if colab_mode:
+                        colab_imgs_to_copy.append(save_img_path)
+
+                    # calculate PSNR
                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
-                    #gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                    gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                    sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
+
+                    # calculate fea loss
+                    avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
 
                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
                     save_img_path = os.path.join(img_dir, img_base_name)
                     util.save_img(sr_img, save_img_path)
-                    if colab_mode:
-                        colab_imgs_to_copy.append(save_img_path)
-
-                    # calculate PSNR (Naw - don't do that. PSNR sucks)
-                    #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                    #avg_psnr += util.calculate_psnr(sr_img, gt_img)
-                    #pbar.update('Test {}'.format(img_name))
-
-                    # calculate fea loss
-                    avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
 
             if colab_mode:
                 util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
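The re-enabled metric is ordinary peak signal-to-noise ratio over uint8 images, computed after trimming a scale-sized border. A minimal sketch of what util.crop_border and util.calculate_psnr conventionally do in BasicSR-style codebases; this is an assumption about those helpers, not a copy of them:

    import numpy as np

    def crop_border(imgs, border):
        # Trim `border` pixels from every edge; SR models are least
        # reliable at image boundaries.
        return [img[border:-border, border:-border, ...] for img in imgs]

    def calculate_psnr(img1, img2):
        # PSNR for uint8 images: 20 * log10(255 / sqrt(MSE)).
        mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
        if mse == 0:
            return float('inf')
        return 20 * np.log10(255.0 / np.sqrt(mse))

    sr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    gt = sr.copy()
    sr_c, gt_c = crop_border([sr, gt], border=4)
    print(calculate_psnr(sr_c, gt_c))  # inf for identical images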
@@ -293,7 +285,7 @@ def main():
                 logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
                 # tensorboard logger
                 if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0:
-                    #tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
+                    tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
                     tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
 
         if rank <= 0:
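The logging calls above assume tb_logger is a tensorboard SummaryWriter, whose add_scalar(tag, value, global_step) signature matches the uncommented line. A minimal usage sketch; the log directory and metric values are illustrative assumptions:

    from torch.utils.tensorboard import SummaryWriter

    tb_logger = SummaryWriter(log_dir='../tb_logger/example')  # illustrative path
    avg_psnr, avg_fea_loss, current_step = 28.4, 0.12, 1000    # dummy values
    tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
    tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
    tb_logger.close()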