From fc55bdb24e10c18566249eec040f3852428da998 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 12 Nov 2020 15:45:25 -0700
Subject: [PATCH] Mods to how wandb is integrated

Initialize wandb only after the full option dict has been parsed, so
the run directory can be placed under the experiment's log path. Also
pass the current step through to feed_data().
---
 codes/train.py  | 15 ++++++++-------
 codes/train2.py | 15 ++++++++-------
 2 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/codes/train.py b/codes/train.py
index 50f6bf14..ec5211f6 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -33,11 +33,6 @@ class Trainer:
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
 
-        #### wandb init
-        if opt['wandb']:
-            import wandb
-            wandb.init(project=opt['name'])
-
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
@@ -81,6 +76,12 @@ class Trainer:
         opt = option.dict_to_nonedict(opt)
         self.opt = opt
 
+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
+            wandb.init(project=opt['name'], dir=opt['path']['log'])
+
         #### random seed
         seed = opt['train']['manual_seed']
         if seed is None:
@@ -156,7 +157,7 @@ class Trainer:
                 if self._profile:
                     print("Update LR: %f" % (time() - _t))
                     _t = time()
-                self.model.feed_data(train_data)
+                self.model.feed_data(train_data, self.current_step)
                 self.model.optimize_parameters(self.current_step)
                 if self._profile:
                     print("Model feed + step: %f" % (time() - _t))
@@ -273,7 +274,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_pyrrrdb_disc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/train2.py b/codes/train2.py
index ab63c7b8..ad4ddb1c 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -33,11 +33,6 @@ class Trainer:
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
 
-        #### wandb init
-        if opt['wandb']:
-            import wandb
-            wandb.init(project=opt['name'])
-
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
@@ -81,6 +76,12 @@ class Trainer:
         opt = option.dict_to_nonedict(opt)
         self.opt = opt
 
+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
+            wandb.init(project=opt['name'], dir=opt['path']['log'])
+
         #### random seed
         seed = opt['train']['manual_seed']
         if seed is None:
@@ -156,7 +157,7 @@ class Trainer:
                 if self._profile:
                     print("Update LR: %f" % (time() - _t))
                     _t = time()
-                self.model.feed_data(train_data)
+                self.model.feed_data(train_data, self.current_step)
                 self.model.optimize_parameters(self.current_step)
                 if self._profile:
                     print("Model feed + step: %f" % (time() - _t))
@@ -273,7 +274,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_nolatent.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()