Mods to how wandb are integrated

This commit is contained in:
James Betker 2020-11-12 15:45:25 -07:00
parent 44a19cd37c
commit fc55bdb24e
2 changed files with 16 additions and 14 deletions

View File

@ -33,11 +33,6 @@ class Trainer:
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
#### wandb init
if opt['wandb']:
import wandb
wandb.init(project=opt['name'])
#### loading resume state if exists
if opt['path'].get('resume_state', None):
# distributed resuming: all load into default GPU
@ -81,6 +76,12 @@ class Trainer:
opt = option.dict_to_nonedict(opt)
self.opt = opt
#### wandb init
if opt['wandb']:
import wandb
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
wandb.init(project=opt['name'], dir=opt['path']['log'])
#### random seed
seed = opt['train']['manual_seed']
if seed is None:
@ -156,7 +157,7 @@ class Trainer:
if self._profile:
print("Update LR: %f" % (time() - _t))
_t = time()
self.model.feed_data(train_data)
self.model.feed_data(train_data, self.current_step)
self.model.optimize_parameters(self.current_step)
if self._profile:
print("Model feed + step: %f" % (time() - _t))
@ -273,7 +274,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_pyrrrdb_disc.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -33,11 +33,6 @@ class Trainer:
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
#### wandb init
if opt['wandb']:
import wandb
wandb.init(project=opt['name'])
#### loading resume state if exists
if opt['path'].get('resume_state', None):
# distributed resuming: all load into default GPU
@ -81,6 +76,12 @@ class Trainer:
opt = option.dict_to_nonedict(opt)
self.opt = opt
#### wandb init
if opt['wandb']:
import wandb
os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
wandb.init(project=opt['name'], dir=opt['path']['log'])
#### random seed
seed = opt['train']['manual_seed']
if seed is None:
@ -156,7 +157,7 @@ class Trainer:
if self._profile:
print("Update LR: %f" % (time() - _t))
_t = time()
self.model.feed_data(train_data)
self.model.feed_data(train_data, self.current_step)
self.model.optimize_parameters(self.current_step)
if self._profile:
print("Model feed + step: %f" % (time() - _t))
@ -273,7 +274,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_nolatent.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()