From 8a83b1c7161a58e551f51b30ad8fabc4e576adc7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 4 Dec 2020 16:39:21 -0700 Subject: [PATCH] Go back to apex DDP, fix distributed bugs --- codes/models/ExtensibleTrainer.py | 6 ++---- codes/models/base_model.py | 20 +++----------------- codes/models/eval/evaluator.py | 4 ++-- codes/models/eval/flow_gaussian_nll.py | 7 +++++-- codes/train.py | 9 +++++---- 5 files changed, 17 insertions(+), 29 deletions(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 7d61a19d..451b5e1d 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -4,7 +4,7 @@ import os import torch from torch.nn.parallel import DataParallel import torch.nn as nn -from torch.nn.parallel.distributed import DistributedDataParallel +from apex.parallel import DistributedDataParallel import models.lr_scheduler as lr_scheduler import models.networks as networks @@ -106,9 +106,7 @@ class ExtensibleTrainer(BaseModel): all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()] for anet in all_networks: if opt['dist']: - dnet = DistributedDataParallel(anet, - device_ids=[torch.cuda.current_device()], - find_unused_parameters=False) + dnet = DistributedDataParallel(anet, delay_allreduce=True) else: dnet = DataParallel(anet, device_ids=opt['gpu_ids']) if self.is_train: diff --git a/codes/models/base_model.py b/codes/models/base_model.py index 55371407..75a5d44d 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -97,8 +97,8 @@ class BaseModel(): return save_path def load_network(self, load_path, network, strict=True): - if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): - network = network.module + #if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module load_net = torch.load(load_path) # Support loading torch.save()s for whole models as well as just state_dicts. @@ -109,21 +109,7 @@ class BaseModel(): load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith('module.'): - load_net_clean[k[7:]] = v - if k.startswith('generator'): # Hack to fix ESRGAN pretrained model. - load_net_clean[k[10:]] = v - if 'RRDB_trunk' in k or is_srflow: # Hacks to fix SRFlow imports, which uses some strange RDB names. - is_srflow = True - fixed_key = k.replace('RRDB_trunk', 'body') - if '.RDB' in fixed_key: - fixed_key = fixed_key.replace('.RDB', '.rdb') - elif '.upconv' in fixed_key: - fixed_key = fixed_key.replace('.upconv', '.conv_up') - elif '.trunk_conv' in fixed_key: - fixed_key = fixed_key.replace('.trunk_conv', '.conv_body') - elif '.HRconv' in fixed_key: - fixed_key = fixed_key.replace('.HRconv', '.conv_hr') - load_net_clean[fixed_key] = v + load_net_clean[k.replace('module.', '')] = v else: load_net_clean[k] = v network.load_state_dict(load_net_clean, strict=strict) diff --git a/codes/models/eval/evaluator.py b/codes/models/eval/evaluator.py index 6e8c665b..5f0a364f 100644 --- a/codes/models/eval/evaluator.py +++ b/codes/models/eval/evaluator.py @@ -1,9 +1,9 @@ # Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response. class Evaluator: def __init__(self, model, opt_eval, env): - self.model = model + self.model = model.module if hasattr(model, 'module') else model self.opt = opt_eval self.env = env def perform_eval(self): - return {} \ No newline at end of file + return {} diff --git a/codes/models/eval/flow_gaussian_nll.py b/codes/models/eval/flow_gaussian_nll.py index f887b7c3..eed85622 100644 --- a/codes/models/eval/flow_gaussian_nll.py +++ b/codes/models/eval/flow_gaussian_nll.py @@ -25,15 +25,18 @@ class FlowGaussianNll(evaluator.Evaluator): def perform_eval(self): total_zs = 0 z_loss = 0 + self.model.eval() with torch.no_grad(): print("Evaluating FlowGaussianNll..") for batch in tqdm(self.dataloader): - z, _, _ = self.model(gt=batch['GT'], - lr=batch['LQ'], + dev = self.env['device'] + z, _, _ = self.model(gt=batch['GT'].to(dev), + lr=batch['LQ'].to(dev), epses=[], reverse=False, add_gt_noise=False) for z_ in z: z_loss += GaussianDiag.logp(None, None, z_).mean() total_zs += 1 + self.model.train() return {"gaussian_diff": z_loss / total_zs} diff --git a/codes/train.py b/codes/train.py index 9995b9f6..aa124dd3 100644 --- a/codes/train.py +++ b/codes/train.py @@ -255,13 +255,14 @@ class Trainer: self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step) self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step) - if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0: + if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0: eval_dict = {} for eval in self.evaluators: eval_dict.update(eval.perform_eval()) - print("Evaluator results: ", eval_dict) - for ek, ev in eval_dict.items(): - self.tb_logger.add_scalar(ek, ev, self.current_step) + if self.rank <= 0: + print("Evaluator results: ", eval_dict) + for ek, ev in eval_dict.items(): + self.tb_logger.add_scalar(ek, ev, self.current_step) def do_training(self): self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))