forked from mrq/DL-Art-School
Introduce (untested) colab mode
This commit is contained in:
parent
a38dd62489
commit
f1a1fd14b1
|
@ -3,6 +3,7 @@ from collections import OrderedDict
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.parallel import DistributedDataParallel
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
import utils.util
|
||||||
|
|
||||||
|
|
||||||
class BaseModel():
|
class BaseModel():
|
||||||
|
@ -84,6 +85,9 @@ class BaseModel():
|
||||||
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
|
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
|
||||||
if 'alt_path' in self.opt['path'].keys():
|
if 'alt_path' in self.opt['path'].keys():
|
||||||
torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
|
torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
|
||||||
|
if self.opt['colab_mode']:
|
||||||
|
utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
|
||||||
|
save_path, os.path.join(self.opt['remote_path'], 'models', save_filename))
|
||||||
return save_path
|
return save_path
|
||||||
|
|
||||||
def load_network(self, load_path, network, strict=True):
|
def load_network(self, load_path, network, strict=True):
|
||||||
|
@ -111,6 +115,9 @@ class BaseModel():
|
||||||
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
|
# Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
|
||||||
if 'alt_path' in self.opt['path'].keys():
|
if 'alt_path' in self.opt['path'].keys():
|
||||||
torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))
|
torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))
|
||||||
|
if self.opt['colab_mode']:
|
||||||
|
utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
|
||||||
|
save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename))
|
||||||
|
|
||||||
def resume_training(self, resume_state):
|
def resume_training(self, resume_state):
|
||||||
"""Resume the optimizers and schedulers for training"""
|
"""Resume the optimizers and schedulers for training"""
|
||||||
|
|
|
@ -4,3 +4,5 @@ lmdb
|
||||||
pyyaml
|
pyyaml
|
||||||
tb-nightly
|
tb-nightly
|
||||||
future
|
future
|
||||||
|
scp
|
||||||
|
tqdm
|
||||||
|
|
|
@ -30,13 +30,31 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/finetune_hoh_resgen_xl_blurring.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imset_pre_rrdb.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = option.parse(args.opt, is_train=True)
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
||||||
|
colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode']
|
||||||
|
if colab_mode:
|
||||||
|
# Check the configuration of the remote server. Expect models, training_state, and val_images directories to be there.
|
||||||
|
# Each one should have a TEST file in it.
|
||||||
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||||
|
os.path.join(opt['remote_path'], 'training_state', "TEST"))
|
||||||
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||||
|
os.path.join(opt['remote_path'], 'models', "TEST"))
|
||||||
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||||
|
os.path.join(opt['remote_path'], 'val_images', "TEST"))
|
||||||
|
# Load the state and models needed from the remote server.
|
||||||
|
if opt['path']['resume_state']:
|
||||||
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state']))
|
||||||
|
if opt['path']['pretrain_model_G']:
|
||||||
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G']))
|
||||||
|
if opt['path']['pretrain_model_D']:
|
||||||
|
util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D']))
|
||||||
|
|
||||||
#### distributed training settings
|
#### distributed training settings
|
||||||
if args.launcher == 'none': # disabled distributed training
|
if args.launcher == 'none': # disabled distributed training
|
||||||
opt['dist'] = False
|
opt['dist'] = False
|
||||||
|
@ -190,6 +208,7 @@ def main():
|
||||||
pbar = util.ProgressBar(len(val_loader))
|
pbar = util.ProgressBar(len(val_loader))
|
||||||
avg_psnr = 0.
|
avg_psnr = 0.
|
||||||
idx = 0
|
idx = 0
|
||||||
|
colab_imgs_to_copy = []
|
||||||
for val_data in val_loader:
|
for val_data in val_loader:
|
||||||
idx += 1
|
idx += 1
|
||||||
if idx >= 20:
|
if idx >= 20:
|
||||||
|
@ -207,15 +226,22 @@ def main():
|
||||||
gt_img = util.tensor2img(visuals['GT']) # uint8
|
gt_img = util.tensor2img(visuals['GT']) # uint8
|
||||||
|
|
||||||
# Save SR images for reference
|
# Save SR images for reference
|
||||||
save_img_path = os.path.join(img_dir,
|
img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
|
||||||
'{:s}_{:d}.png'.format(img_name, current_step))
|
save_img_path = os.path.join(img_dir, img_base_name)
|
||||||
util.save_img(sr_img, save_img_path)
|
util.save_img(sr_img, save_img_path)
|
||||||
|
if colab_mode:
|
||||||
|
colab_imgs_to_copy.append(save_img_path)
|
||||||
|
|
||||||
# calculate PSNR
|
# calculate PSNR
|
||||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||||
pbar.update('Test {}'.format(img_name))
|
pbar.update('Test {}'.format(img_name))
|
||||||
|
|
||||||
|
if colab_mode:
|
||||||
|
util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
|
||||||
|
colab_imgs_to_copy,
|
||||||
|
os.path.join(opt['remote_path'], 'val_images', img_base_name))
|
||||||
|
|
||||||
avg_psnr = avg_psnr / idx
|
avg_psnr = avg_psnr / idx
|
||||||
|
|
||||||
# log
|
# log
|
||||||
|
|
|
@ -12,6 +12,8 @@ import cv2
|
||||||
import torch
|
import torch
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
from shutil import get_terminal_size
|
from shutil import get_terminal_size
|
||||||
|
import scp
|
||||||
|
import paramiko
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
try:
|
try:
|
||||||
|
@ -90,6 +92,21 @@ def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tof
|
||||||
sh.setFormatter(formatter)
|
sh.setFormatter(formatter)
|
||||||
lg.addHandler(sh)
|
lg.addHandler(sh)
|
||||||
|
|
||||||
|
def copy_files_to_server(host, user, password, files, remote_path):
    """Upload one or more local files to a remote host over SCP.

    A fresh SSH connection is opened for each call and is always closed
    before returning, even when the transfer fails.

    Args:
        host: Hostname or address of the SSH server.
        user: Username used to authenticate.
        password: Password used to authenticate.
        files: Local path, or list of local paths, passed straight through
            to ``SCPClient.put``.
        remote_path: Destination path on the remote host.
    """
    client = paramiko.SSHClient()
    try:
        client.load_system_host_keys()
        # NOTE(review): AutoAddPolicy trusts any unknown host key without
        # verification. Kept for compatibility with existing setups, but
        # consider pinning the server's key for untrusted networks.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(host, username=user, password=password)
        scpclient = scp.SCPClient(client.get_transport())
        try:
            scpclient.put(files, remote_path)
        finally:
            # Release the SCP channel before tearing down the transport.
            scpclient.close()
    finally:
        # Without this, every upload would leave an SSH session open.
        client.close()
|
||||||
|
|
||||||
|
def get_files_from_server(host, user, password, remote_path, local_path=''):
    """Download a file from a remote host over SCP.

    A fresh SSH connection is opened for each call and is always closed
    before returning, even when the transfer fails.

    Args:
        host: Hostname or address of the SSH server.
        user: Username used to authenticate.
        password: Password used to authenticate.
        remote_path: Path of the file to fetch from the remote host.
        local_path: Where to store the file locally. Defaults to ``''``,
            the same default ``SCPClient.get`` uses, so callers that only
            supply ``remote_path`` no longer fail with a ``TypeError``.
    """
    client = paramiko.SSHClient()
    try:
        client.load_system_host_keys()
        # NOTE(review): AutoAddPolicy trusts any unknown host key without
        # verification. Kept for compatibility with existing setups, but
        # consider pinning the server's key for untrusted networks.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(host, username=user, password=password)
        scpclient = scp.SCPClient(client.get_transport())
        try:
            scpclient.get(remote_path, local_path)
        finally:
            # Release the SCP channel before tearing down the transport.
            scpclient.close()
    finally:
        # Without this, every download would leave an SSH session open.
        client.close()
|
||||||
|
|
||||||
####################
|
####################
|
||||||
# image convert
|
# image convert
|
||||||
|
|
Loading…
Reference in New Issue
Block a user