forked from mrq/DL-Art-School

commit b54de69153 (parent 71c3820d2d)

    Misc

The hunks below touch the IDE's VCS mappings, TiledDataset's minimum-size filter, and several parts of the training script (imports, distributed init, step-resume logic, and validation).
@@ -2,5 +2,7 @@
 <project version="4">
   <component name="VcsDirectoryMappings">
     <mapping directory="" vcs="Git" />
+    <mapping directory="$PROJECT_DIR$/codes/models/flownet2" vcs="Git" />
+    <mapping directory="$PROJECT_DIR$/codes/switched_conv" vcs="Git" />
   </component>
 </project>
@@ -242,7 +242,7 @@ class TiledDataset(data.Dataset):
 
         h, w, c = img.shape
         # Uncomment to filter any image that doesnt meet a threshold size.
-        if min(h,w) < 512:
+        if min(h,w) < 1024:
             return None
         left = 0
         right = w
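For readers outside the repo: this dataset tiles large source images into training crops, and the touched line is the gate that drops any image whose shorter side falls below a threshold (raised here from 512 to 1024). A minimal standalone sketch of that gate, with illustrative names rather than the repo's exact helpers:

import numpy as np

def filter_undersized(img: np.ndarray, min_dim: int = 1024):
    # Drop images whose shorter side is below min_dim; the dataset
    # treats a None return as "skip this sample entirely".
    h, w, c = img.shape
    if min(h, w) < min_dim:
        return None
    return img

# An 800x600 image passes the old 512 threshold but fails the new 1024 one.
print(filter_undersized(np.zeros((600, 800, 3), dtype=np.uint8)))  # None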
@@ -8,9 +8,9 @@ from tqdm import tqdm
 import torch
 from data.data_sampler import DistIterSampler
 
-from models.ExtensibleTrainer import ExtensibleTrainer
 from utils import util, options as option
 from data import create_dataloader, create_dataset
+from models.ExtensibleTrainer import ExtensibleTrainer
 from time import time
 
 
@@ -19,7 +19,7 @@ def init_dist(backend='nccl', **kwargs):
     import torch.distributed as dist
     import torch.multiprocessing as mp
 
-    """initialization for distributneed training"""
+    """initialization for distributed training"""
     if mp.get_start_method(allow_none=True) != 'spawn':
         mp.set_start_method('spawn')
     rank = int(os.environ['RANK'])
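Only a docstring typo changes here, but for context on the surrounding helper: a standard NCCL setup of this shape typically continues along the lines of the sketch below. This is an assumption-laden paraphrase, not a verbatim copy of the file; RANK is injected by the launch utility (e.g. torch.distributed.launch).

import os
import torch

def init_dist(backend='nccl', **kwargs):
    """Initialization for distributed training (sketch)."""
    # Deferred imports: these modules have import-time side effects on Windows.
    import torch.distributed as dist
    import torch.multiprocessing as mp

    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])          # provided by the launcher
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)  # pin this process to one GPU
    dist.init_process_group(backend=backend, **kwargs)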
@@ -122,7 +122,7 @@ def main():
 
     torch.backends.cudnn.benchmark = True
     # torch.backends.cudnn.deterministic = True
-    #torch.autograd.set_detect_anomaly(True)
+    # torch.autograd.set_detect_anomaly(True)
 
     # Save the compiled opt dict to the global loaded_options variable.
     util.loaded_options = opt
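The change above is whitespace-only; the anomaly-detection switch stays commented out. When enabled, torch.autograd.set_detect_anomaly(True) makes the backward pass raise at, and name, the operation that produced a NaN gradient, at a substantial speed penalty, which is presumably why it is normally off. A small self-contained demonstration:

import torch

torch.autograd.set_detect_anomaly(True)  # debug only: big slowdown
try:
    x = torch.tensor([-1.0], requires_grad=True)
    torch.sqrt(x).backward()  # sqrt(-1) is NaN, so its gradient is NaN too
except RuntimeError as err:
    print(err)  # the message names SqrtBackward as the offending op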
@@ -170,6 +170,8 @@ def main():
     else:
         current_step = -1 if 'start_step' not in opt.keys() else opt['start_step']
         start_epoch = 0
+    if 'force_start_step' in opt.keys():
+        current_step = opt['force_start_step']
 
     #### training
     logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
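The two added lines introduce a 'force_start_step' option that unconditionally overrides the resume counter, whether or not a checkpoint supplied one. The resolution order after this hunk, as a standalone sketch (opt is reduced to a plain dict for illustration):

def resolve_start_step(opt: dict, resume_step=None) -> int:
    # Checkpoint resume wins first, then an optional 'start_step'.
    if resume_step is not None:
        current_step = resume_step
    else:
        current_step = opt.get('start_step', -1)
    # New in this commit: 'force_start_step' overrides everything,
    # presumably useful to replay schedules from a chosen iteration.
    if 'force_start_step' in opt:
        current_step = opt['force_start_step']
    return current_step

print(resolve_start_step({'start_step': 100}))                           # 100
print(resolve_start_step({'force_start_step': 5000}, resume_step=1234))  # 5000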
@@ -185,10 +187,6 @@ def main():
             print("Data fetch: %f" % (time() - _t))
             _t = time()
 
-            #tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
-            #                    train_data['lq_fullsize_ref'].float().to('cuda'),
-            #                    train_data['lq_center'].to('cuda')])
-
             current_step += 1
             if current_step > total_iters:
                 break
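The deleted block was a stale commented-out experiment with TensorBoard graph tracing of the generator. For reference, the generic pattern it gestured at (the module and input shape below are placeholders, not the repo's):

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('tb_logger/graph_demo')    # hypothetical log dir
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
writer.add_graph(net, torch.randn(1, 3, 64, 64))  # trace once, log the graph
writer.close()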
@@ -241,9 +239,6 @@ def main():
            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0:  # image restoration validation
-                    model.force_restore_swapout()
-                    val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
-                    # does not support multi-GPU validation
                    avg_psnr = 0.
                    avg_fea_loss = 0.
                    idx = 0
@@ -263,23 +258,22 @@ def main():
                        if visuals is None:
                            continue
 
+                        if colab_mode:
+                            colab_imgs_to_copy.append(save_img_path)
+
+                        # calculate PSNR
                        sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
-                        #gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
+
+                        # calculate fea loss
+                        avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
 
                        # Save SR images for reference
                        img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
                        save_img_path = os.path.join(img_dir, img_base_name)
                        util.save_img(sr_img, save_img_path)
-                        if colab_mode:
-                            colab_imgs_to_copy.append(save_img_path)
-
-                        # calculate PSNR (Naw - don't do that. PSNR sucks)
-                        #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                        #avg_psnr += util.calculate_psnr(sr_img, gt_img)
-                        #pbar.update('Test {}'.format(img_name))
-
-                        # calculate fea loss
-                        avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
 
                if colab_mode:
                    util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
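This hunk reverses the earlier "PSNR sucks" decision: the ground-truth tensor is decoded again, both images are border-cropped by the scale factor (so edge pixels a scale-s upsampler cannot be blamed for are excluded), and PSNR accumulates alongside the feature loss. For intuition, the conventional uint8 PSNR definition, which util.calculate_psnr presumably follows (the repo's implementation may differ in details):

import numpy as np

def calculate_psnr(img1: np.ndarray, img2: np.ndarray) -> float:
    # PSNR = 20 * log10(MAX / sqrt(MSE)) with MAX = 255 for uint8 images.
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 20 * np.log10(255.0 / np.sqrt(mse))

a = np.zeros((8, 8, 3), dtype=np.uint8)
b = np.full((8, 8, 3), 10, dtype=np.uint8)  # uniform error of 10 -> MSE 100
print(round(calculate_psnr(a, b), 2))       # 28.13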
@@ -293,7 +287,7 @@ def main():
                logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0:
-                    #tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
+                    tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
 
        if rank <= 0:
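With val_psnr uncommented, both validation metrics now stream to TensorBoard under their own tags. Standard SummaryWriter usage for anyone reproducing the logging outside this codebase (the log-dir name and values are illustrative):

from torch.utils.tensorboard import SummaryWriter

tb_logger = SummaryWriter('tb_logger/experiment')
for current_step, (avg_psnr, avg_fea_loss) in enumerate([(27.5, 0.12), (28.1, 0.10)]):
    tb_logger.add_scalar('val_psnr', avg_psnr, current_step)   # tag, value, global step
    tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
tb_logger.close()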