Enable usage of wandb

commit 88f349bdf1
parent 1c065c41b4
.gitignore (vendored)

@@ -4,6 +4,7 @@ tb_logger/*
 datasets/*
 options/*
 codes/*.txt
+codes/wandb/*
 .vscode
 
 *.html
ExtensibleTrainer

@@ -67,6 +67,9 @@ class ExtensibleTrainer(BaseModel):
 
             if not net['trainable']:
                 new_net.eval()
+            if net['wandb_debug']:
+                import wandb
+                wandb.watch(new_net, log='all', log_freq=3)
 
         # Initialize the train/eval steps
         self.step_names = []
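For context, here is a minimal, self-contained sketch (not part of this commit) of the wandb.watch pattern the hunk above enables. The network, project name, and loop are placeholders; in the repository the watch is attached to each new_net whose options set wandb_debug.

import torch
import torch.nn as nn
import wandb

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))  # stand-in for new_net
wandb.init(project='wandb-watch-demo', mode='offline')  # offline run; no account needed for the sketch
wandb.watch(net, log='all', log_freq=3)  # same arguments as in the diff: weights and gradients every 3 steps

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
for step in range(6):
    loss = net(torch.randn(4, 8)).mean()
    loss.backward()                      # gradients are captured by the watch hooks
    optimizer.step()
    optimizer.zero_grad()
    wandb.log({'loss': loss.item()})     # watched histograms are flushed alongside regular log calls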
First training script

@@ -33,6 +33,11 @@ class Trainer:
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
 
+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            wandb.init(project=opt['name'])
+
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
@@ -174,6 +179,9 @@ class Trainer:
                     # tensorboard logger
                     if opt['use_tb_logger'] and 'debug' not in opt['name']:
                         self.tb_logger.add_scalar(k, v, self.current_step)
+                if opt['wandb']:
+                    import wandb
+                    wandb.log(logs)
                 self.logger.info(message)
 
             #### save models and training states
@@ -265,7 +273,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_pyrrrdb_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
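A compact sketch of the opt-gated pattern the trainer now follows: wandb.init is called once per run under the experiment name, and wandb.log mirrors the same scalar dict that already feeds the tensorboard logger. The plain dict and metric names below are placeholders, not values from this repository.

import wandb

opt = {'name': 'example_experiment', 'wandb': True}  # stand-in for the parsed YAML options

if opt['wandb']:
    wandb.init(project=opt['name'])  # one run per experiment name; add mode='offline' to run without an account

for current_step in range(3):
    logs = {'l_total': 1.0 / (current_step + 1), 'psnr': 24.0 + current_step}  # dummy metrics
    if opt['wandb']:
        wandb.log(logs)  # same dict passed to tb_logger.add_scalar, logged once per training step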
Second training script

@@ -13,40 +13,30 @@ from data import create_dataloader, create_dataset
 from models.ExtensibleTrainer import ExtensibleTrainer
 from time import time
 
-class Trainer:
-    def init_dist(self, backend, **kwargs):
-        # These packages have globals that screw with Windows, so only import them if needed.
-        import torch.distributed as dist
-        import torch.multiprocessing as mp
-
-        """initialization for distributed training"""
-        if mp.get_start_method(allow_none=True) != 'spawn':
-            mp.set_start_method('spawn')
-        self.rank = int(os.environ['RANK'])
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(self.rank % num_gpus)
-        dist.init_process_group(backend=backend, **kwargs)
+def init_dist(backend, **kwargs):
+    # These packages have globals that screw with Windows, so only import them if needed.
+    import torch.distributed as dist
+    import torch.multiprocessing as mp
+
+    """initialization for distributed training"""
+    if mp.get_start_method(allow_none=True) != 'spawn':
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+class Trainer:
 
     def init(self, opt, launcher, all_networks={}):
         self._profile = False
+        self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
+        self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
 
-        #### distributed training settings
-        if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
-            gpu = input(
-                'I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
-                'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
-            if gpu:
-                opt['gpu_ids'] = [int(gpu)]
-        if launcher == 'none':  # disabled distributed training
-            opt['dist'] = False
-            self.rank = -1
-            print('Disabled distributed training.')
-
-        else:
-            opt['dist'] = True
-            self.init_dist()
-            world_size = torch.distributed.get_world_size()
-            self.rank = torch.distributed.get_rank()
+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            wandb.init(project=opt['name'])
 
         #### loading resume state if exists
        if opt['path'].get('resume_state', None):
@@ -115,11 +105,11 @@ class Trainer:
                 total_iters = int(opt['train']['niter'])
                 self.total_epochs = int(math.ceil(total_iters / train_size))
                 if opt['dist']:
-                    train_sampler = DistIterSampler(self.train_set, world_size, self.rank, dataset_ratio)
+                    self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
                     self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                 else:
-                    train_sampler = None
-                self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, train_sampler)
+                    self.train_sampler = None
+                self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler)
                 if self.rank <= 0:
                     self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                         len(self.train_set), train_size))
@@ -157,8 +147,6 @@ class Trainer:
                 print("Data fetch: %f" % (time() - _t))
                 _t = time()
 
-            #self.tb_logger.add_graph(self.model.netsG['generator'].module, input_to_model=torch.randn((1,3,32,32), device='cuda:0'))
-
             opt = self.opt
             self.current_step += 1
             #### update learning rate
@@ -191,6 +179,9 @@ class Trainer:
                     # tensorboard logger
                     if opt['use_tb_logger'] and 'debug' not in opt['name']:
                         self.tb_logger.add_scalar(k, v, self.current_step)
+                if opt['wandb']:
+                    import wandb
+                    wandb.log(logs)
                 self.logger.info(message)
 
             #### save models and training states
@@ -216,8 +207,8 @@ class Trainer:
             val_tqdm = tqdm(self.val_loader)
             for val_data in val_tqdm:
                 idx += 1
-                for b in range(len(val_data['LQ_path'])):
-                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
+                for b in range(len(val_data['GT_path'])):
+                    img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
                     img_dir = os.path.join(opt['path']['val_images'], img_name)
                     util.mkdir(img_dir)
 
@@ -228,14 +219,16 @@ class Trainer:
                     if visuals is None:
                         continue
 
-                    # calculate PSNR
                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
-                    gt_img = util.tensor2img(visuals['GT'][b])  # uint8
-                    sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
+                    # calculate PSNR
+                    if self.val_compute_psnr:
+                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
 
                     # calculate fea loss
-                    avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+                    if self.val_compute_fea:
+                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
 
                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
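The new val_compute_psnr / val_compute_fea flags let a config skip either validation metric. A tiny sketch of that gating follows; util.calculate_psnr is not shown in this diff, so a standard PSNR formula and random images stand in for it.

import numpy as np

def psnr(img1, img2, max_val=255.0):
    # standard PSNR on uint8 images; stand-in for util.calculate_psnr
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(max_val / np.sqrt(mse))

val_compute_psnr = True  # mirrors opt['eval']['compute_psnr']
avg_psnr, idx = 0.0, 0

sr_img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # dummy SR output
gt_img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # dummy ground truth

idx += 1
if val_compute_psnr:         # the metric is skipped entirely when the config disables it
    avg_psnr += psnr(sr_img, gt_img)

print(avg_psnr / idx)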
@@ -280,10 +273,24 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_resdisc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_nolatent.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     trainer = Trainer()
+
+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        trainer.rank = -1
+        print('Disabled distributed training.')
+
+    else:
+        opt['dist'] = True
+        init_dist('nccl')
+        trainer.world_size = torch.distributed.get_world_size()
+        trainer.rank = torch.distributed.get_rank()
+
     trainer.init(opt, args.launcher)
     trainer.do_training()
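Since init_dist() is now module-level and reads RANK from the environment before calling dist.init_process_group, the 'pytorch' launcher path above assumes the usual env:// variables have been set by the launcher (torch.distributed.launch or torchrun). A hedged sketch of that contract, with illustrative single-process values and a CPU-friendly backend so it runs outside the repository:

import os
import torch
import torch.distributed as dist

# Illustrative only: a launcher is expected to provide these before init_dist('nccl') runs.
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

backend = 'nccl' if torch.cuda.is_available() else 'gloo'  # gloo keeps the sketch runnable on CPU
dist.init_process_group(backend=backend)                   # same call init_dist() makes, minus the GPU pinning
print(dist.get_rank(), dist.get_world_size())
dist.destroy_process_group()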