Enable usage of wandb

James Betker 2020-11-11 21:48:56 -07:00
parent 1c065c41b4
commit 88f349bdf1
4 changed files with 62 additions and 43 deletions
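The commit threads Weights & Biases through three hooks: a guarded wandb.init when the trainer starts, an optional wandb.watch on any network flagged with wandb_debug, and a wandb.log call beside the existing tensorboard logging. A minimal sketch of the combined pattern, with stand-in option values in place of the parsed YAML config:

    import os
    os.environ['WANDB_MODE'] = 'offline'  # keeps the sketch runnable without a wandb account

    import torch.nn as nn
    import wandb

    opt = {'name': 'some_experiment', 'wandb': True}  # stand-in for the parsed options
    net = nn.Linear(8, 8)                             # stand-in for a trainer network

    if opt['wandb']:
        wandb.init(project=opt['name'])          # one run, filed under the experiment name
        wandb.watch(net, log='all', log_freq=3)  # gradient + parameter histograms
        wandb.log({'l_g_total': 0.12})           # scalar metrics, once per logging step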

.gitignore

@@ -4,6 +4,7 @@ tb_logger/*
 datasets/*
 options/*
 codes/*.txt
+codes/wandb/*
 .vscode
 *.html

codes/models/ExtensibleTrainer.py

@@ -67,6 +67,9 @@ class ExtensibleTrainer(BaseModel):
             if not net['trainable']:
                 new_net.eval()
+            if net['wandb_debug']:
+                import wandb
+                wandb.watch(new_net, log='all', log_freq=3)

         # Initialize the train/eval steps
         self.step_names = []
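On the watch call above: log='all' records both gradient and parameter histograms for the hooked module, and log_freq is the number of logged batches between captures, so 3 is an aggressive debugging cadence. A self-contained sketch (the network and project name are placeholders):

    import os
    os.environ['WANDB_MODE'] = 'offline'  # no wandb server needed

    import torch
    import torch.nn as nn
    import wandb

    wandb.init(project='watch_demo')  # placeholder project name
    net = nn.Linear(4, 1)
    wandb.watch(net, log='all', log_freq=3)  # capture histograms every 3rd logged batch

    for step in range(6):
        loss = net(torch.randn(2, 4)).mean()
        loss.backward()                   # gradients now exist for 'all' to record
        wandb.log({'loss': loss.item()})
        net.zero_grad()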

codes/train.py

@@ -33,6 +33,11 @@ class Trainer:
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True

+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            wandb.init(project=opt['name'])
+
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
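Note that the import sits inside the guard, so wandb remains an optional dependency: training still works on machines without the package as long as the option is off. The same pattern as a helper (maybe_init_wandb is hypothetical, not part of this commit):

    def maybe_init_wandb(opt):
        if opt.get('wandb'):
            import wandb  # deferred import: only required when the option is enabled
            wandb.init(project=opt['name'])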
@@ -174,6 +179,9 @@ class Trainer:
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        self.tb_logger.add_scalar(k, v, self.current_step)
+                if opt['wandb']:
+                    import wandb
+                    wandb.log(logs)

                self.logger.info(message)
        #### save models and training states
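Where tensorboard takes one add_scalar call per key, wandb.log uploads the whole logs dict in a single call. A sketch with illustrative values; passing step= would pin the x-axis to the trainer's step counter instead of wandb's internal one:

    import os
    os.environ['WANDB_MODE'] = 'offline'

    import wandb

    wandb.init(project='log_demo')  # placeholder project name
    logs = {'l_g_total': 0.12, 'l_d_real': 0.7, 'psnr': 24.3}  # illustrative metrics
    wandb.log(logs, step=100)       # one call covers every key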
@@ -265,7 +273,7 @@

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_pyrrrdb_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

codes/train2.py

@@ -13,40 +13,30 @@ from data import create_dataloader, create_dataset
 from models.ExtensibleTrainer import ExtensibleTrainer
 from time import time

-class Trainer:
-    def init_dist(self, backend, **kwargs):
-        # These packages have globals that screw with Windows, so only import them if needed.
-        import torch.distributed as dist
-        import torch.multiprocessing as mp
-
-        """initialization for distributed training"""
-        if mp.get_start_method(allow_none=True) != 'spawn':
-            mp.set_start_method('spawn')
-        self.rank = int(os.environ['RANK'])
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(self.rank % num_gpus)
-        dist.init_process_group(backend=backend, **kwargs)
+def init_dist(backend, **kwargs):
+    # These packages have globals that screw with Windows, so only import them if needed.
+    import torch.distributed as dist
+    import torch.multiprocessing as mp
+
+    """initialization for distributed training"""
+    if mp.get_start_method(allow_none=True) != 'spawn':
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)

+class Trainer:
     def init(self, opt, launcher, all_networks={}):
         self._profile = False
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True

-        #### distributed training settings
-        if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
-            gpu = input('I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
-                        'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
-            if gpu:
-                opt['gpu_ids'] = [int(gpu)]
-        if launcher == 'none':  # disabled distributed training
-            opt['dist'] = False
-            self.rank = -1
-            print('Disabled distributed training.')
-        else:
-            opt['dist'] = True
-            self.init_dist()
-            world_size = torch.distributed.get_world_size()
-            self.rank = torch.distributed.get_rank()
+        #### wandb init
+        if opt['wandb']:
+            import wandb
+            wandb.init(project=opt['name'])

         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
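The refactor turns init_dist from a Trainer method into a module-level function so the process group can be created in __main__ before any Trainer exists; rank and world size are then read back from torch.distributed and assigned onto the trainer (see the new __main__ block below). A single-process demonstration of the underlying calls, on the gloo backend so it runs without GPUs or a launcher:

    import os
    import torch.distributed as dist

    # A real launch sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE per process;
    # here they are pinned so one process stands in for the whole group.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)
    print(dist.get_rank(), dist.get_world_size())  # -> 0 1
    dist.destroy_process_group()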
@@ -115,11 +105,11 @@ class Trainer:
             total_iters = int(opt['train']['niter'])
             self.total_epochs = int(math.ceil(total_iters / train_size))
             if opt['dist']:
-                train_sampler = DistIterSampler(self.train_set, world_size, self.rank, dataset_ratio)
+                self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
                 self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
             else:
-                train_sampler = None
-            self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, train_sampler)
+                self.train_sampler = None
+            self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler)
             if self.rank <= 0:
                 self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                     len(self.train_set), train_size))
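Promoting train_sampler to self.train_sampler keeps the sampler reachable after init, which matters for distributed samplers because set_epoch must be called each epoch to reseed the shuffle. A stand-alone sketch with torch's stock DistributedSampler, the contract DistIterSampler mirrors:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(8))
    # rank/num_replicas pinned so the demo needs no process group
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reseed the shuffle; why the sampler must stay reachable
        for (batch,) in loader:
            pass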
@@ -157,8 +147,6 @@ class Trainer:
                     print("Data fetch: %f" % (time() - _t))
                     _t = time()

-                #self.tb_logger.add_graph(self.model.netsG['generator'].module, input_to_model=torch.randn((1,3,32,32), device='cuda:0'))
-
                 opt = self.opt
                 self.current_step += 1
                 #### update learning rate
@@ -191,6 +179,9 @@ class Trainer:
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        self.tb_logger.add_scalar(k, v, self.current_step)
+                if opt['wandb']:
+                    import wandb
+                    wandb.log(logs)

                self.logger.info(message)
        #### save models and training states
@@ -216,8 +207,8 @@ class Trainer:
             val_tqdm = tqdm(self.val_loader)
             for val_data in val_tqdm:
                 idx += 1
-                for b in range(len(val_data['LQ_path'])):
-                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][b]))[0]
+                for b in range(len(val_data['GT_path'])):
+                    img_name = os.path.splitext(os.path.basename(val_data['GT_path'][b]))[0]
                     img_dir = os.path.join(opt['path']['val_images'], img_name)
                     util.mkdir(img_dir)
@@ -228,14 +219,16 @@ class Trainer:
                     if visuals is None:
                         continue

-                    # calculate PSNR
                     sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
-                    gt_img = util.tensor2img(visuals['GT'][b])  # uint8
-                    sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
+                    # calculate PSNR
+                    if self.val_compute_psnr:
+                        gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                        sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                        avg_psnr += util.calculate_psnr(sr_img, gt_img)

                     # calculate fea loss
-                    avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
+                    if self.val_compute_fea:
+                        avg_fea_loss += self.model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])

                     # Save SR images for reference
                     img_base_name = '{:s}_{:d}.png'.format(img_name, self.current_step)
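The PSNR and feature-loss passes are now gated by the val_compute_psnr / val_compute_fea flags read in init, both defaulting to True. The `x if k in d else True` expressions used there are equivalent to dict.get with a default:

    eval_opt = {'compute_fea': False}  # illustrative fragment of opt['eval']

    val_compute_psnr = eval_opt.get('compute_psnr', True)  # -> True (key absent)
    val_compute_fea = eval_opt.get('compute_fea', True)    # -> False (explicit opt-out)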
@@ -280,10 +273,24 @@

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_resdisc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_adalatent_mi1_rrdb4x_6bl_nolatent.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

     trainer = Trainer()

+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        trainer.rank = -1
+        print('Disabled distributed training.')
+    else:
+        opt['dist'] = True
+        init_dist('nccl')
+        trainer.world_size = torch.distributed.get_world_size()
+        trainer.rank = torch.distributed.get_rank()
+
     trainer.init(opt, args.launcher)
     trainer.do_training()
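With the setup moved into __main__, launch selection happens before Trainer.init ever runs. How the two branches are exercised, sketched as comments (the script and config names are placeholders; torch.distributed.launch sets RANK per process and fills the --local_rank argument defined above):

    # Single process, no DDP (the 'none' branch):
    #   python train2.py -opt ../options/my_config.yml
    #
    # One process per GPU via DDP (the 'pytorch' branch):
    #   python -m torch.distributed.launch --nproc_per_node=2 train2.py \
    #       -opt ../options/my_config.yml --launcher pytorch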