Enable usage of wandb

James Betker 2020-11-11 21:48:56 -07:00
parent 1c065c41b4
commit 88f349bdf1
4 changed files with 62 additions and 43 deletions

.gitignore

@@ -4,6 +4,7 @@ tb_logger/*
 datasets/*
 options/*
 codes/*.txt
+codes/wandb/*
 .vscode
 *.html

codes/models/ExtensibleTrainer.py

@@ -67,6 +67,9 @@ class ExtensibleTrainer(BaseModel):
             if not net['trainable']:
                 new_net.eval()
+            if net['wandb_debug']:
+                import wandb
+                wandb.watch(new_net, log='all', log_freq=3)

             # Initialize the train/eval steps
             self.step_names = []
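For context, wandb.watch() hooks the module so that gradient and parameter histograms are sent to Weights & Biases; log='all' captures both, and log_freq controls how often (in logging steps) they are recorded. A minimal standalone sketch of the same call, with a stand-in network and a hypothetical project name:

# Sketch: the wandb.watch() pattern used above, on a hypothetical network.
import torch
import torch.nn as nn
import wandb

wandb.init(project='demo-project')       # hypothetical project name
net = nn.Linear(16, 4)                   # stand-in for new_net
wandb.watch(net, log='all', log_freq=3)  # record gradients and parameters

loss = net(torch.randn(8, 16)).sum()
loss.backward()
wandb.log({'loss': loss.item()})         # histograms ride along with logged steps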

codes/train.py

@@ -33,6 +33,11 @@ class Trainer:
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True

+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            wandb.init(project=opt['name'])
+
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
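wandb.init() starts a run and attaches all subsequent wandb.log() calls to the named project, here reused from the experiment name. Since the new code branches on opt['wandb'], the option YAML presumably needs a matching top-level flag; a minimal sketch of the relevant opt entries (keys other than 'wandb' and 'name' are omitted, values illustrative):

# Sketch: the opt dict as option.parse() would hand it over (values illustrative).
opt = {
    'name': 'train_example_experiment',  # doubles as the wandb project name
    'wandb': True,                       # new flag consumed by the code above
}

if opt['wandb']:
    import wandb
    wandb.init(project=opt['name'])      # one run per training job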
@@ -174,6 +179,9 @@ class Trainer:
                 # tensorboard logger
                 if opt['use_tb_logger'] and 'debug' not in opt['name']:
                     self.tb_logger.add_scalar(k, v, self.current_step)
+            if opt['wandb']:
+                import wandb
+                wandb.log(logs)
             self.logger.info(message)

         #### save models and training states
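wandb.log() takes a dict of scalars and records them as one step, so the same logs dict that already feeds the tensorboard loop can be pushed wholesale. A sketch with made-up metric names:

# Sketch: pushing the per-print-interval scalars to wandb (metric names made up).
import wandb

wandb.init(project='demo-project')             # hypothetical project name
logs = {'l_g_total': 0.231, 'l_d_real': 0.045}
wandb.log(logs)                                # one row; wandb keeps its own step counter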
@@ -265,7 +273,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_pyrrrdb_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

codes/train2.py

@@ -13,40 +13,30 @@ from data import create_dataloader, create_dataset
 from models.ExtensibleTrainer import ExtensibleTrainer
 from time import time

-class Trainer:
-    def init_dist(self, backend, **kwargs):
-        # These packages have globals that screw with Windows, so only import them if needed.
-        import torch.distributed as dist
-        import torch.multiprocessing as mp
+def init_dist(backend, **kwargs):
+    # These packages have globals that screw with Windows, so only import them if needed.
+    import torch.distributed as dist
+    import torch.multiprocessing as mp

-        """initialization for distributed training"""
-        if mp.get_start_method(allow_none=True) != 'spawn':
-            mp.set_start_method('spawn')
-        self.rank = int(os.environ['RANK'])
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(self.rank % num_gpus)
-        dist.init_process_group(backend=backend, **kwargs)
+    """initialization for distributed training"""
+    if mp.get_start_method(allow_none=True) != 'spawn':
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)

+class Trainer:
     def init(self, opt, launcher, all_networks={}):
         self._profile = False
+        self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
+        self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True

-        #### distributed training settings
-        if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
-            gpu = input(
-                'I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
-                'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
-            if gpu:
-                opt['gpu_ids'] = [int(gpu)]
-        if launcher == 'none':  # disabled distributed training
-            opt['dist'] = False
-            self.rank = -1
-            print('Disabled distributed training.')
-        else:
-            opt['dist'] = True
-            self.init_dist()
-            world_size = torch.distributed.get_world_size()
-            self.rank = torch.distributed.get_rank()
+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            wandb.init(project=opt['name'])

         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
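Hoisting init_dist() out of the Trainer class lets the __main__ block (last hunk below) bring up the process group before Trainer.init() runs. It leans on the environment contract that PyTorch's launcher establishes; a sketch of that contract (the script name is this file, the GPU count illustrative):

# Sketch: the env contract init_dist() depends on. A command such as
#   python -m torch.distributed.launch --nproc_per_node=2 train2.py --launcher pytorch
# spawns one process per GPU with RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT set.
import torch.distributed as dist

dist.init_process_group(backend='nccl')  # default init_method='env://' reads those vars
print('rank %d of %d' % (dist.get_rank(), dist.get_world_size()))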
@@ -115,11 +105,11 @@ class Trainer:
         total_iters = int(opt['train']['niter'])
         self.total_epochs = int(math.ceil(total_iters / train_size))
         if opt['dist']:
-            train_sampler = DistIterSampler(self.train_set, world_size, self.rank, dataset_ratio)
+            self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
             self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
         else:
-            train_sampler = None
-        self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, train_sampler)
+            self.train_sampler = None
+        self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler)
         if self.rank <= 0:
             self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                 len(self.train_set), train_size))
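Promoting train_sampler to an attribute (and reading self.world_size, which __main__ now sets) suggests the training loop wants the sampler later, most often to reshuffle per epoch. DistIterSampler is project-specific, but the standard pattern looks like this:

# Sketch: why a distributed sampler is usually kept around (uses the stock
# DistributedSampler; DistIterSampler may differ). num_replicas/rank are passed
# explicitly so this runs without an initialized process group.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseed the shuffle so every epoch has a new order
    for (batch,) in loader:
        pass                  # training step would go here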
@ -157,8 +147,6 @@ class Trainer:
print("Data fetch: %f" % (time() - _t)) print("Data fetch: %f" % (time() - _t))
_t = time() _t = time()
#self.tb_logger.add_graph(self.model.netsG['generator'].module, input_to_model=torch.randn((1,3,32,32), device='cuda:0'))
opt = self.opt opt = self.opt
self.current_step += 1 self.current_step += 1
#### update learning rate #### update learning rate
@@ -191,6 +179,9 @@ class Trainer:
                 # tensorboard logger
                 if opt['use_tb_logger'] and 'debug' not in opt['name']:
                     self.tb_logger.add_scalar(k, v, self.current_step)
+            if opt['wandb']:
+                import wandb
+                wandb.log(logs)
             self.logger.info(message)

         #### save models and training states
@@ -216,8 +207,8 @@ class Trainer:
                 val_tqdm = tqdm(self.val_loader)
                 for val_data in val_tqdm:
                     idx += 1
-                    for b in range(len(val_data['LQ_path'])):
-                        img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
+                    for b in range(len(val_data['GT_path'])):
+                        img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
                         img_dir = os.path.join(opt['path']['val_images'], img_name)
                         util.mkdir(img_dir)
@@ -228,14 +219,16 @@ class Trainer:
                         if visuals is None:
                             continue

-                        # calculate PSNR
                         sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
-                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
-                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
+                        # calculate PSNR
+                        if self.val_compute_psnr:
+                            gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                            sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                            avg_psnr += util.calculate_psnr(sr_img, gt_img)

                         # calculate fea loss
-                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+                        if self.val_compute_fea:
+                            avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])

                         # Save SR images for reference
                         img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
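For reference, util.calculate_psnr presumably computes the standard peak signal-to-noise ratio, 10 * log10(MAX^2 / MSE) with MAX = 255 for uint8 images; a self-contained sketch:

# Sketch: PSNR between two uint8 images (the metric the gated branch accumulates).
import numpy as np

def psnr(img1, img2):
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')          # identical images
    return 10 * np.log10(255.0 ** 2 / mse)

gt = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
sr = np.clip(gt.astype(np.int16) + 5, 0, 255).astype(np.uint8)
print('PSNR: %.2f dB' % psnr(gt, sr))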
@@ -280,10 +273,24 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_resdisc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_nolatent.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     trainer = Trainer()
+
+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        trainer.rank = -1
+        print('Disabled distributed training.')
+    else:
+        opt['dist'] = True
+        init_dist('nccl')
+        trainer.world_size = torch.distributed.get_world_size()
+        trainer.rank = torch.distributed.get_rank()
+
     trainer.init(opt, args.launcher)
     trainer.do_training()
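Setting trainer.rank (and trainer.world_size) before trainer.init() runs means the rest of the class can gate master-only work, such as the "if self.rank <= 0" logging check in an earlier hunk, on a single convention: -1 for non-distributed runs, 0 for the master process. A minimal illustration of that convention (the class body here is hypothetical):

# Sketch: the rank convention wired up in __main__ (illustrative).
class Trainer:
    def __init__(self):
        self.rank = -1          # -1: distributed training disabled

    def log_master(self, msg):
        if self.rank <= 0:      # master process, or a single-GPU run
            print(msg)

t = Trainer()
t.log_master('printed once, by rank<=0 only')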