Misc
This commit is contained in:
parent
71c3820d2d
commit
b54de69153
|
@ -2,5 +2,7 @@
|
|||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/codes/models/flownet2" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/codes/switched_conv" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
|
@ -242,7 +242,7 @@ class TiledDataset(data.Dataset):
|
|||
|
||||
h, w, c = img.shape
|
||||
# Uncomment to filter any image that doesnt meet a threshold size.
|
||||
if min(h,w) < 512:
|
||||
if min(h,w) < 1024:
|
||||
return None
|
||||
left = 0
|
||||
right = w
|
||||
|
|
|
@ -8,9 +8,9 @@ from tqdm import tqdm
|
|||
import torch
|
||||
from data.data_sampler import DistIterSampler
|
||||
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from utils import util, options as option
|
||||
from data import create_dataloader, create_dataset
|
||||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from time import time
|
||||
|
||||
|
||||
|
@ -19,7 +19,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
"""initialization for distributneed training"""
|
||||
"""initialization for distributed training"""
|
||||
if mp.get_start_method(allow_none=True) != 'spawn':
|
||||
mp.set_start_method('spawn')
|
||||
rank = int(os.environ['RANK'])
|
||||
|
@ -122,7 +122,7 @@ def main():
|
|||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
#torch.autograd.set_detect_anomaly(True)
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
# Save the compiled opt dict to the global loaded_options variable.
|
||||
util.loaded_options = opt
|
||||
|
@ -170,6 +170,8 @@ def main():
|
|||
else:
|
||||
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
|
||||
start_epoch = 0
|
||||
if 'force_start_step' in opt.keys():
|
||||
current_step = opt['force_start_step']
|
||||
|
||||
#### training
|
||||
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
||||
|
@ -185,10 +187,6 @@ def main():
|
|||
print("Data fetch: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
#tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
|
||||
# train_data['lq_fullsize_ref'].float().to('cuda'),
|
||||
# train_data['lq_center'].to('cuda')])
|
||||
|
||||
current_step += 1
|
||||
if current_step > total_iters:
|
||||
break
|
||||
|
@ -241,9 +239,6 @@ def main():
|
|||
#### validation
|
||||
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
|
||||
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
|
||||
model.force_restore_swapout()
|
||||
val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
|
||||
# does not support multi-GPU validation
|
||||
avg_psnr = 0.
|
||||
avg_fea_loss = 0.
|
||||
idx = 0
|
||||
|
@ -263,23 +258,22 @@ def main():
|
|||
if visuals is None:
|
||||
continue
|
||||
|
||||
if colab_mode:
|
||||
colab_imgs_to_copy.append(save_img_path)
|
||||
|
||||
# calculate PSNR
|
||||
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
||||
#gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
|
||||
# calculate fea loss
|
||||
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
# Save SR images for reference
|
||||
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
|
||||
save_img_path = os.path.join(img_dir, img_base_name)
|
||||
util.save_img(sr_img, save_img_path)
|
||||
if colab_mode:
|
||||
colab_imgs_to_copy.append(save_img_path)
|
||||
|
||||
# calculate PSNR (Naw - don't do that. PSNR sucks)
|
||||
#sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
#avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
#pbar.update('Test {}'.format(img_name))
|
||||
|
||||
# calculate fea loss
|
||||
avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
if colab_mode:
|
||||
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||
|
@ -293,7 +287,7 @@ def main():
|
|||
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
|
||||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0:
|
||||
#tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
|
||||
tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
|
||||
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
|
||||
|
||||
if rank <= 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user