diff --git a/codes/models/base_model.py b/codes/models/base_model.py index 68959b3d..1b8348d1 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -3,6 +3,7 @@ from collections import OrderedDict import torch import torch.nn as nn from torch.nn.parallel import DistributedDataParallel +import utils.util class BaseModel(): @@ -84,6 +85,9 @@ class BaseModel(): # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example. if 'alt_path' in self.opt['path'].keys(): torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename)) + if self.opt['colab_mode']: + utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'], + save_path, os.path.join(self.opt['remote_path'], 'models', save_filename)) return save_path def load_network(self, load_path, network, strict=True): @@ -111,6 +115,9 @@ class BaseModel(): # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example. if 'alt_path' in self.opt['path'].keys(): torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state')) + if self.opt['colab_mode']: + utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'], + save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename)) def resume_training(self, resume_state): """Resume the optimizers and schedulers for training""" diff --git a/codes/requirements.txt b/codes/requirements.txt index a107c0e7..c28aebf0 100644 --- a/codes/requirements.txt +++ b/codes/requirements.txt @@ -4,3 +4,5 @@ lmdb pyyaml tb-nightly future +scp +tqdm diff --git a/codes/train.py b/codes/train.py index c1cdc062..c2173eca 100644 --- a/codes/train.py +++ b/codes/train.py @@ -30,13 +30,31 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/finetune_hoh_resgen_xl_blurring.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imset_pre_rrdb.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() opt = option.parse(args.opt, is_train=True) + colab_mode = False if 'colab_mode' not in opt.keys() else opt['colab_mode'] + if colab_mode: + # Check the configuration of the remote server. Expect models, resume_state, and val_images directories to be there. + # Each one should have a TEST file in it. + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], + os.path.join(opt['remote_path'], 'training_state', "TEST")) + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], + os.path.join(opt['remote_path'], 'models', "TEST")) + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], + os.path.join(opt['remote_path'], 'val_images', "TEST")) + # Load the state and models needed from the remote server. + if opt['path']['resume_state']: + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'training_state', opt['path']['resume_state'])) + if opt['path']['pretrain_model_G']: + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_G'])) + if opt['path']['pretrain_model_D']: + util.get_files_from_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], os.path.join(opt['remote_path'], 'models', opt['path']['pretrain_model_D'])) + #### distributed training settings if args.launcher == 'none': # disabled distributed training opt['dist'] = False @@ -190,6 +208,7 @@ def main(): pbar = util.ProgressBar(len(val_loader)) avg_psnr = 0. idx = 0 + colab_imgs_to_copy = [] for val_data in val_loader: idx += 1 if idx >= 20: @@ -207,15 +226,22 @@ def main(): gt_img = util.tensor2img(visuals['GT']) # uint8 # Save SR images for reference - save_img_path = os.path.join(img_dir, - '{:s}_{:d}.png'.format(img_name, current_step)) + img_base_name = '{:s}_{:d}.png'.format(img_name, current_step) + save_img_path = os.path.join(img_dir, img_base_name) util.save_img(sr_img, save_img_path) + if colab_mode: + colab_imgs_to_copy.append(save_img_path) # calculate PSNR sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) avg_psnr += util.calculate_psnr(sr_img, gt_img) pbar.update('Test {}'.format(img_name)) + if colab_mode: + util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'], + colab_imgs_to_copy, + os.path.join(opt['remote_path'], 'val_images', img_base_name)) + avg_psnr = avg_psnr / idx # log diff --git a/codes/utils/util.py b/codes/utils/util.py index a8924fa8..fe3d39a8 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -12,6 +12,8 @@ import cv2 import torch from torchvision.utils import make_grid from shutil import get_terminal_size +import scp +import paramiko import yaml try: @@ -90,6 +92,21 @@ def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tof sh.setFormatter(formatter) lg.addHandler(sh) +def copy_files_to_server(host, user, password, files, remote_path): + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect(host, username=user, password=password) + scpclient = scp.SCPClient(client.get_transport()) + scpclient.put(files, remote_path) + +def get_files_from_server(host, user, password, remote_path, local_path): + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect(host, username=user, password=password) + scpclient = scp.SCPClient(client.get_transport()) + scpclient.get(remote_path, local_path) #################### # image convert