Mods to how wandb is integrated
parent 44a19cd37c
commit fc55bdb24e
@@ -33,11 +33,6 @@ class Trainer:
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
 
-        #### wandb init
-        if opt['wandb']:
-            import wandb
-            wandb.init(project=opt['name'])
-
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
@@ -81,6 +76,12 @@ class Trainer:
         opt = option.dict_to_nonedict(opt)
         self.opt = opt
 
+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
+            wandb.init(project=opt['name'], dir=opt['path']['log'])
+
         #### random seed
         seed = opt['train']['manual_seed']
         if seed is None:
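The relocated block now runs after option.dict_to_nonedict(opt) has finalized the option dict, so opt['path']['log'] is available; it pre-creates the wandb subdirectory and passes dir= so run files land under the experiment's log path rather than the current working directory. A minimal standalone sketch of the same pattern, with placeholder option values that are not taken from this repo's configs:

import os
import wandb

# Placeholder option dict standing in for the YAML-driven 'opt' in the trainer.
opt = {
    'name': 'my_experiment',
    'wandb': True,
    'path': {'log': 'experiments/my_experiment/log'},
}

if opt['wandb']:
    # wandb stores run files in a 'wandb' subdirectory of dir=; creating it
    # up front avoids wandb falling back to the current working directory.
    os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
    wandb.init(project=opt['name'], dir=opt['path']['log'])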
@@ -156,7 +157,7 @@ class Trainer:
             if self._profile:
                 print("Update LR: %f" % (time() - _t))
                 _t = time()
-            self.model.feed_data(train_data)
+            self.model.feed_data(train_data, self.current_step)
             self.model.optimize_parameters(self.current_step)
             if self._profile:
                 print("Model feed + step: %f" % (time() - _t))
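feed_data now receives the trainer's global step along with the batch. The model classes themselves are not part of this diff; a hedged sketch of a compatible signature, where the parameter and attribute names are assumptions:

class ExampleModel:
    def feed_data(self, data, step=None):
        # 'step' defaults to None so older feed_data(train_data) call sites
        # keep working; only the call site is visible in this commit.
        self.data = data
        self.current_step = step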
@@ -273,7 +274,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_pyrrrdb_disc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_faster.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
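With the run initialized under the log path, metrics can be reported against the same step counter that feed_data and optimize_parameters receive. This is illustrative usage of wandb's public API, not code from this commit:

import wandb

run = wandb.init(project='my_experiment')  # placeholder project name
current_step = 157                         # placeholder global step counter

# Passing step= keeps wandb charts aligned with training iterations.
wandb.log({'train/loss': 0.123}, step=current_step)
run.finish()

The same wandb relocation is applied to a second trainer script below; only its -opt default differs, moving to train_stylegan2.yml.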
@@ -33,11 +33,6 @@ class Trainer:
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
 
-        #### wandb init
-        if opt['wandb']:
-            import wandb
-            wandb.init(project=opt['name'])
-
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
@@ -81,6 +76,12 @@ class Trainer:
         opt = option.dict_to_nonedict(opt)
         self.opt = opt
 
+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            os.makedirs(os.path.join(opt['path']['log'], 'wandb'), exist_ok=True)
+            wandb.init(project=opt['name'], dir=opt['path']['log'])
+
         #### random seed
         seed = opt['train']['manual_seed']
         if seed is None:
@@ -156,7 +157,7 @@ class Trainer:
             if self._profile:
                 print("Update LR: %f" % (time() - _t))
                 _t = time()
-            self.model.feed_data(train_data)
+            self.model.feed_data(train_data, self.current_step)
             self.model.optimize_parameters(self.current_step)
             if self._profile:
                 print("Model feed + step: %f" % (time() - _t))
@@ -273,7 +274,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_nolatent.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()