James Betker 2020-10-21 11:08:21 -06:00
parent 71c3820d2d
commit b54de69153
3 changed files with 19 additions and 23 deletions

View File

@@ -2,5 +2,7 @@
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/codes/models/flownet2" vcs="Git" />
<mapping directory="$PROJECT_DIR$/codes/switched_conv" vcs="Git" />
</component>
</project>

View File

@@ -242,7 +242,7 @@ class TiledDataset(data.Dataset):
h, w, c = img.shape
# Uncomment to filter any image that doesnt meet a threshold size.
-if min(h,w) < 512:
+if min(h,w) < 1024:
return None
left = 0
right = w
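The change in this hunk doubles the minimum accepted image dimension from 512 to 1024 px. As a reading aid, a minimal sketch of that filtering pattern, with hypothetical names (MIN_DIM, filter_undersized) and cv2 assumed for loading; this is not the repo's exact code:

    import cv2

    MIN_DIM = 1024  # smallest allowed height/width, matching the new threshold

    def filter_undersized(path):
        """Return the loaded image, or None when either dimension is too small."""
        img = cv2.imread(path)  # HWC uint8, or None for an unreadable file
        if img is None:
            return None
        h, w, c = img.shape
        if min(h, w) < MIN_DIM:
            return None  # caller (e.g. a collate_fn) is expected to drop None entries
        return img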

View File

@@ -8,9 +8,9 @@ from tqdm import tqdm
import torch
from data.data_sampler import DistIterSampler
-from models.ExtensibleTrainer import ExtensibleTrainer
from utils import util, options as option
from data import create_dataloader, create_dataset
+from models.ExtensibleTrainer import ExtensibleTrainer
from time import time
@@ -19,7 +19,7 @@ def init_dist(backend='nccl', **kwargs):
import torch.distributed as dist
import torch.multiprocessing as mp
"""initialization for distributneed training"""
"""initialization for distributed training"""
if mp.get_start_method(allow_none=True) != 'spawn':
mp.set_start_method('spawn')
rank = int(os.environ['RANK'])
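The docstring typo is the only change here; the surrounding lines are standard PyTorch distributed setup. A self-contained sketch of the same pattern, assuming one process per GPU with RANK exported by the launcher:

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def init_dist(backend='nccl', **kwargs):
        """Initialization for distributed training."""
        if mp.get_start_method(allow_none=True) != 'spawn':
            mp.set_start_method('spawn')  # required for CUDA tensors in worker processes
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)  # pin this process to one device
        dist.init_process_group(backend=backend, **kwargs)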
@@ -122,7 +122,7 @@ def main():
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
-#torch.autograd.set_detect_anomaly(True)
+# torch.autograd.set_detect_anomaly(True)
# Save the compiled opt dict to the global loaded_options variable.
util.loaded_options = opt
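The edit here is only a whitespace tweak in a commented-out line, but the three flags around it embody the usual speed/debuggability trade. A hedged sketch of how they are typically toggled (DEBUG is a hypothetical switch; the script hard-codes the fast path):

    import torch

    DEBUG = False  # hypothetical; flip on when hunting NaN/Inf gradients

    if DEBUG:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True  # reproducible but slower kernels
        torch.autograd.set_detect_anomaly(True)    # report the op that produced bad grads
    else:
        torch.backends.cudnn.benchmark = True      # autotune kernels; wins when input sizes are fixed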
@@ -170,6 +170,8 @@ def main():
else:
current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
start_epoch = 0
+if 'force_start_step' in opt.keys():
+    current_step = opt['force_start_step']
#### training
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
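The two added lines introduce a force_start_step override for the iteration counter. A minimal sketch of the resulting precedence, assuming opt is the parsed options dict and resume_step is non-None when restoring a checkpoint (resolve_start_step is a hypothetical helper):

    def resolve_start_step(opt, resume_step=None):
        """Initial iteration counter: checkpoint resume > start_step > fresh (-1)."""
        if resume_step is not None:
            current_step = resume_step
        else:
            current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
        if 'force_start_step' in opt.keys():
            current_step = opt['force_start_step']  # explicit override always wins
        return current_step

    # e.g. resolve_start_step({'force_start_step': 50000}, resume_step=120000) -> 50000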
@@ -185,10 +187,6 @@ def main():
print("Data fetch: %f" % (time() - _t))
_t = time()
-#tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
-# train_data['lq_fullsize_ref'].float().to('cuda'),
-# train_data['lq_center'].to('cuda')])
current_step += 1
if current_step > total_iters:
break
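Around the deleted graph-logging block, the loop times each batch fetch and stops once total_iters is reached. A sketch of that skeleton, with the dataloader and the optimization step left abstract (do_step is a caller-supplied callable, not this repo's API):

    from time import time

    def run_epoch(train_loader, total_iters, current_step, do_step):
        _t = time()
        for train_data in train_loader:
            print("Data fetch: %f" % (time() - _t))  # wall-clock cost of loading the batch
            _t = time()
            current_step += 1
            if current_step > total_iters:
                break
            do_step(train_data, current_step)  # the actual optimization step
        return current_step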
@@ -241,9 +241,6 @@ def main():
#### validation
if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0: # image restoration validation
-model.force_restore_swapout()
-val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
-# does not support multi-GPU validation
avg_psnr = 0.
avg_fea_loss = 0.
idx = 0
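The deleted lines drop the swapout restore and an unused validation batch size; validation itself stays gated to a fixed frequency and to rank 0, since this path does not support multi-GPU validation. A sketch of that gate (run_validation is a caller-supplied callable, not this repo's API):

    def maybe_validate(opt, current_step, rank, run_validation):
        """Run single-process validation every val_freq steps, on rank 0 only."""
        if not opt['datasets'].get('val', None):
            return
        if current_step % opt['train']['val_freq'] != 0:
            return
        if rank > 0:
            return  # multi-GPU validation is unsupported here
        run_validation()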
@@ -263,23 +258,22 @@ def main():
if visuals is None:
continue
-if colab_mode:
-    colab_imgs_to_copy.append(save_img_path)
# calculate PSNR
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
-#gt_img = util.tensor2img(visuals['GT'][b]) # uint8
+gt_img = util.tensor2img(visuals['GT'][b]) # uint8
+sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+avg_psnr += util.calculate_psnr(sr_img, gt_img)
+# calculate fea loss
+avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
# Save SR images for reference
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
save_img_path = os.path.join(img_dir, img_base_name)
util.save_img(sr_img, save_img_path)
+if colab_mode:
+    colab_imgs_to_copy.append(save_img_path)
-# calculate PSNR (Naw - don't do that. PSNR sucks)
-#sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-#avg_psnr += util.calculate_psnr(sr_img, gt_img)
-#pbar.update('Test {}'.format(img_name))
-# calculate fea loss
-avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
if colab_mode:
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
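The restored metric is plain PSNR on border-cropped uint8 images. A minimal sketch of what util.crop_border and util.calculate_psnr are assumed to compute (the standard definitions; the repo's own implementations are not shown in this diff):

    import numpy as np

    def crop_border(imgs, border):
        """Trim `border` pixels (here: the SR scale factor) from every edge."""
        if border == 0:
            return imgs
        return [img[border:-border, border:-border, ...] for img in imgs]

    def calculate_psnr(img1, img2):
        """PSNR for 8-bit images: 20 * log10(255 / sqrt(MSE))."""
        mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
        if mse == 0:
            return float('inf')  # identical images
        return 20 * np.log10(255.0 / np.sqrt(mse))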
@@ -293,7 +287,7 @@ def main():
logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
# tensorboard logger
if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0:
-#tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
+tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
if rank <= 0:
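With val_psnr uncommented, both validation scalars now reach TensorBoard. A sketch using the stock SummaryWriter, which tb_logger is assumed to be (log dir and values are placeholders):

    from torch.utils.tensorboard import SummaryWriter

    tb_logger = SummaryWriter(log_dir='tb_logger/example')    # hypothetical path
    avg_psnr, avg_fea_loss, current_step = 27.3, 0.042, 1000  # placeholder values

    tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
    tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
    tb_logger.flush()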