Go back to apex DDP, fix distributed bugs

James Betker 2020-12-04 16:39:21 -07:00
parent 7a81d4e2f4
commit 8a83b1c716
5 changed files with 17 additions and 29 deletions

View File

@@ -4,7 +4,7 @@ import os
 import torch
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
-from torch.nn.parallel.distributed import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -106,9 +106,7 @@ class ExtensibleTrainer(BaseModel):
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet,
-                                               device_ids=[torch.cuda.current_device()],
-                                               find_unused_parameters=False)
+                dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
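
For context, a minimal sketch of how apex's DistributedDataParallel is typically brought up. The bootstrap (LOCAL_RANK, process-group init) is an assumption about the launcher, not something this commit configures:

    import os
    import torch
    import torch.distributed as dist
    from apex.parallel import DistributedDataParallel

    # Assumed bootstrap: one process per GPU; torchrun exports LOCAL_RANK
    # and the rendezvous variables (MASTER_ADDR/MASTER_PORT).
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank)

    net = torch.nn.Linear(16, 16).cuda()  # stand-in for any model
    # apex's wrapper infers devices from the module's parameters, so it takes
    # no device_ids or find_unused_parameters arguments. delay_allreduce=True
    # postpones the gradient all-reduce until backward() finishes, which
    # tolerates parameters that receive no gradient in a given step (the same
    # problem torch DDP addresses with find_unused_parameters).
    net = DistributedDataParallel(net, delay_allreduce=True)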

View File

@@ -97,8 +97,8 @@ class BaseModel():
         return save_path

     def load_network(self, load_path, network, strict=True):
-        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
-            network = network.module
+        #if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+        network = network.module
         load_net = torch.load(load_path)
         # Support loading torch.save()s for whole models as well as just state_dicts.
@@ -109,21 +109,7 @@ class BaseModel():
         load_net_clean = OrderedDict()  # remove unnecessary 'module.'
         for k, v in load_net.items():
             if k.startswith('module.'):
-                load_net_clean[k[7:]] = v
-            if k.startswith('generator'):  # Hack to fix ESRGAN pretrained model.
-                load_net_clean[k[10:]] = v
-            if 'RRDB_trunk' in k or is_srflow:  # Hacks to fix SRFlow imports, which uses some strange RDB names.
-                is_srflow = True
-                fixed_key = k.replace('RRDB_trunk', 'body')
-                if '.RDB' in fixed_key:
-                    fixed_key = fixed_key.replace('.RDB', '.rdb')
-                elif '.upconv' in fixed_key:
-                    fixed_key = fixed_key.replace('.upconv', '.conv_up')
-                elif '.trunk_conv' in fixed_key:
-                    fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
-                elif '.HRconv' in fixed_key:
-                    fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
-                load_net_clean[fixed_key] = v
+                load_net_clean[k.replace('module.', '')] = v
             else:
                 load_net_clean[k] = v
         network.load_state_dict(load_net_clean, strict=strict)
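
A standalone sketch of the simplified loader this hunk converges on, with the ESRGAN/SRFlow key-rewriting hacks dropped (function name and structure here are illustrative, not the repo's exact code):

    from collections import OrderedDict
    import torch

    def load_network(load_path, network, strict=True):
        # Unwrap DataParallel/DistributedDataParallel uniformly instead of
        # isinstance-checking a specific wrapper class.
        if hasattr(network, 'module'):
            network = network.module
        load_net = torch.load(load_path)
        clean = OrderedDict()
        for k, v in load_net.items():
            if k.startswith('module.'):
                # Checkpoints saved from a wrapped model prefix every key
                # with 'module.'; strip it so keys match the bare module.
                clean[k.replace('module.', '')] = v
            else:
                clean[k] = v
        network.load_state_dict(clean, strict=strict)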

View File

@@ -1,9 +1,9 @@
 # Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
 class Evaluator:
     def __init__(self, model, opt_eval, env):
-        self.model = model
+        self.model = model.module if hasattr(model, 'module') else model
         self.opt = opt_eval
         self.env = env

     def perform_eval(self):
         return {}
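
The unwrap idiom works for bare and wrapped models alike, since both DataParallel and apex's DistributedDataParallel expose the inner network as .module; presumably the evaluator wants the bare module so evaluation forwards don't go through the DDP machinery, which is built around the training step's backward/all-reduce cycle. A two-line sketch (net is illustrative):

    import torch

    net = torch.nn.Linear(4, 4)  # may or may not be DataParallel/DDP-wrapped
    net = net.module if hasattr(net, 'module') else net  # always the bare module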

View File

@@ -25,15 +25,18 @@ class FlowGaussianNll(evaluator.Evaluator):
     def perform_eval(self):
         total_zs = 0
         z_loss = 0
+        self.model.eval()
         with torch.no_grad():
             print("Evaluating FlowGaussianNll..")
             for batch in tqdm(self.dataloader):
-                z, _, _ = self.model(gt=batch['GT'],
-                                     lr=batch['LQ'],
+                dev = self.env['device']
+                z, _, _ = self.model(gt=batch['GT'].to(dev),
+                                     lr=batch['LQ'].to(dev),
                                      epses=[],
                                      reverse=False,
                                      add_gt_noise=False)
                 for z_ in z:
                     z_loss += GaussianDiag.logp(None, None, z_).mean()
                     total_zs += 1
+        self.model.train()
         return {"gaussian_diff": z_loss / total_zs}

View File

@@ -255,13 +255,14 @@ class Trainer:
                 self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
                 self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)

-        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
+        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0:
             eval_dict = {}
             for eval in self.evaluators:
                 eval_dict.update(eval.perform_eval())
-            print("Evaluator results: ", eval_dict)
-            for ek, ev in eval_dict.items():
-                self.tb_logger.add_scalar(ek, ev, self.current_step)
+            if self.rank <= 0:
+                print("Evaluator results: ", eval_dict)
+                for ek, ev in eval_dict.items():
+                    self.tb_logger.add_scalar(ek, ev, self.current_step)

     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
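
A minimal sketch of the rank-gating convention used above. The assumption (from this codebase's usage, not a torch guarantee) is that rank is set to -1 for non-distributed runs, so `rank <= 0` passes both for single-process runs and for the rank-0 worker, and exactly one process prints and writes tensorboard scalars:

    import torch.distributed as dist

    eval_dict = {"psnr": 0.0}  # placeholder results
    rank = dist.get_rank() if dist.is_initialized() else -1
    if rank <= 0:
        print("Evaluator results: ", eval_dict)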