Enable usage of wandb
This commit is contained in:
parent
1c065c41b4
commit
88f349bdf1
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -4,6 +4,7 @@ tb_logger/*
|
|||
datasets/*
|
||||
options/*
|
||||
codes/*.txt
|
||||
codes/wandb/*
|
||||
.vscode
|
||||
|
||||
*.html
|
||||
|
|
|
@ -67,6 +67,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
if not net['trainable']:
|
||||
new_net.eval()
|
||||
if net['wandb_debug']:
|
||||
import wandb
|
||||
wandb.watch(new_net, log='all', log_freq=3)
|
||||
|
||||
# Initialize the train/eval steps
|
||||
self.step_names = []
|
||||
|
|
|
@ -33,6 +33,11 @@ class Trainer:
|
|||
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
|
||||
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
|
||||
|
||||
#### wandb init
|
||||
if opt['wandb']:
|
||||
import wandb
|
||||
wandb.init(project=opt['name'])
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
# distributed resuming: all load into default GPU
|
||||
|
@ -174,6 +179,9 @@ class Trainer:
|
|||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
self.tb_logger.add_scalar(k, v, self.current_step)
|
||||
if opt['wandb']:
|
||||
import wandb
|
||||
wandb.log(logs)
|
||||
self.logger.info(message)
|
||||
|
||||
#### save models and training states
|
||||
|
@ -265,7 +273,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_pyrrrdb_disc.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -13,8 +13,7 @@ from data import create_dataloader, create_dataset
|
|||
from models.ExtensibleTrainer import ExtensibleTrainer
|
||||
from time import time
|
||||
|
||||
class Trainer:
|
||||
def init_dist(self, backend, **kwargs):
|
||||
def init_dist(backend, **kwargs):
|
||||
# These packages have globals that screw with Windows, so only import them if needed.
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
@ -22,31 +21,22 @@ class Trainer:
|
|||
"""initialization for distributed training"""
|
||||
if mp.get_start_method(allow_none=True) != 'spawn':
|
||||
mp.set_start_method('spawn')
|
||||
self.rank = int(os.environ['RANK'])
|
||||
rank = int(os.environ['RANK'])
|
||||
num_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(self.rank % num_gpus)
|
||||
torch.cuda.set_device(rank % num_gpus)
|
||||
dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
class Trainer:
|
||||
|
||||
def init(self, opt, launcher, all_networks={}):
|
||||
self._profile = False
|
||||
self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
|
||||
self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
|
||||
|
||||
#### distributed training settings
|
||||
if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
|
||||
gpu = input(
|
||||
'I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
|
||||
'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
|
||||
if gpu:
|
||||
opt['gpu_ids'] = [int(gpu)]
|
||||
if launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
self.rank = -1
|
||||
print('Disabled distributed training.')
|
||||
|
||||
else:
|
||||
opt['dist'] = True
|
||||
self.init_dist()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
self.rank = torch.distributed.get_rank()
|
||||
#### wandb init
|
||||
if opt['wandb']:
|
||||
import wandb
|
||||
wandb.init(project=opt['name'])
|
||||
|
||||
#### loading resume state if exists
|
||||
if opt['path'].get('resume_state', None):
|
||||
|
@ -115,11 +105,11 @@ class Trainer:
|
|||
total_iters = int(opt['train']['niter'])
|
||||
self.total_epochs = int(math.ceil(total_iters / train_size))
|
||||
if opt['dist']:
|
||||
train_sampler = DistIterSampler(self.train_set, world_size, self.rank, dataset_ratio)
|
||||
self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
|
||||
self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
|
||||
else:
|
||||
train_sampler = None
|
||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, train_sampler)
|
||||
self.train_sampler = None
|
||||
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler)
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
||||
len(self.train_set), train_size))
|
||||
|
@ -157,8 +147,6 @@ class Trainer:
|
|||
print("Data fetch: %f" % (time() - _t))
|
||||
_t = time()
|
||||
|
||||
#self.tb_logger.add_graph(self.model.netsG['generator'].module, input_to_model=torch.randn((1,3,32,32), device='cuda:0'))
|
||||
|
||||
opt = self.opt
|
||||
self.current_step += 1
|
||||
#### update learning rate
|
||||
|
@ -191,6 +179,9 @@ class Trainer:
|
|||
# tensorboard logger
|
||||
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
||||
self.tb_logger.add_scalar(k, v, self.current_step)
|
||||
if opt['wandb']:
|
||||
import wandb
|
||||
wandb.log(logs)
|
||||
self.logger.info(message)
|
||||
|
||||
#### save models and training states
|
||||
|
@ -216,8 +207,8 @@ class Trainer:
|
|||
val_tqdm = tqdm(self.val_loader)
|
||||
for val_data in val_tqdm:
|
||||
idx += 1
|
||||
for b in range(len(val_data['LQ_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
|
||||
for b in range(len(val_data['GT_path'])):
|
||||
img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
|
||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
util.mkdir(img_dir)
|
||||
|
||||
|
@ -228,13 +219,15 @@ class Trainer:
|
|||
if visuals is None:
|
||||
continue
|
||||
|
||||
# calculate PSNR
|
||||
sr_img = util.tensor2img(visuals['rlt'][b]) # uint8
|
||||
# calculate PSNR
|
||||
if self.val_compute_psnr:
|
||||
gt_img = util.tensor2img(visuals['GT'][b]) # uint8
|
||||
sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
|
||||
avg_psnr += util.calculate_psnr(sr_img, gt_img)
|
||||
|
||||
# calculate fea loss
|
||||
if self.val_compute_fea:
|
||||
avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
|
||||
|
||||
# Save SR images for reference
|
||||
|
@ -280,10 +273,24 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_resdisc.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_nolatent.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
trainer = Trainer()
|
||||
|
||||
#### distributed training settings
|
||||
if args.launcher == 'none': # disabled distributed training
|
||||
opt['dist'] = False
|
||||
trainer.rank = -1
|
||||
print('Disabled distributed training.')
|
||||
|
||||
else:
|
||||
opt['dist'] = True
|
||||
init_dist('nccl')
|
||||
trainer.world_size = torch.distributed.get_world_size()
|
||||
trainer.rank = torch.distributed.get_rank()
|
||||
|
||||
trainer.init(opt, args.launcher)
|
||||
trainer.do_training()
|
||||
|
|
Loading…
Reference in New Issue
Block a user