Go back to apex DDP, fix distributed bugs

James Betker 2020-12-04 16:39:21 -07:00
parent 7a81d4e2f4
commit 8a83b1c716
5 changed files with 17 additions and 29 deletions

View File

@@ -4,7 +4,7 @@ import os
 import torch
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
-from torch.nn.parallel.distributed import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -106,9 +106,7 @@ class ExtensibleTrainer(BaseModel):
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet,
-                                               device_ids=[torch.cuda.current_device()],
-                                               find_unused_parameters=False)
+                dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
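
Note: apex's DistributedDataParallel does its own gradient bucketing, so the torch-DDP arguments (device_ids, find_unused_parameters) disappear, and delay_allreduce=True defers all gradient communication to the end of the backward pass, which can help when not every parameter receives a gradient on every step. A minimal sketch of the wrapping decision above, assuming opt is a config dict with 'dist' and 'gpu_ids' keys and that the module is already on the current CUDA device:

import torch.nn as nn
from torch.nn.parallel import DataParallel

def wrap_net(anet: nn.Module, opt: dict) -> nn.Module:
    if opt['dist']:
        # apex DDP: delay_allreduce=True waits for the full backward pass
        # before reducing gradients across ranks.
        from apex.parallel import DistributedDataParallel
        return DistributedDataParallel(anet, delay_allreduce=True)
    # Single-process fallback: replicate the module across the configured GPUs.
    return DataParallel(anet, device_ids=opt['gpu_ids'])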

View File

@@ -97,8 +97,8 @@ class BaseModel():
         return save_path

     def load_network(self, load_path, network, strict=True):
-        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
-            network = network.module
+        #if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+        network = network.module
         load_net = torch.load(load_path)

         # Support loading torch.save()s for whole models as well as just state_dicts.
@@ -109,21 +109,7 @@ class BaseModel():
         load_net_clean = OrderedDict()  # remove unnecessary 'module.'
         for k, v in load_net.items():
             if k.startswith('module.'):
-                load_net_clean[k[7:]] = v
-            if k.startswith('generator'):  # Hack to fix ESRGAN pretrained model.
-                load_net_clean[k[10:]] = v
-            if 'RRDB_trunk' in k or is_srflow:  # Hacks to fix SRFlow imports, which uses some strange RDB names.
-                is_srflow = True
-                fixed_key = k.replace('RRDB_trunk', 'body')
-                if '.RDB' in fixed_key:
-                    fixed_key = fixed_key.replace('.RDB', '.rdb')
-                elif '.upconv' in fixed_key:
-                    fixed_key = fixed_key.replace('.upconv', '.conv_up')
-                elif '.trunk_conv' in fixed_key:
-                    fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
-                elif '.HRconv' in fixed_key:
-                    fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
-                load_net_clean[fixed_key] = v
+                load_net_clean[k.replace('module.', '')] = v
             else:
                 load_net_clean[k] = v
         network.load_state_dict(load_net_clean, strict=strict)
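
The collapsed load_network() now normalizes every wrapper-prefixed checkpoint with one key rewrite instead of the per-architecture ESRGAN/SRFlow hacks. A hedged sketch of that normalization as a standalone helper (strip_module_prefix is a hypothetical name, and the sketch assumes the checkpoint is a plain state_dict; the diff also unwraps network.module unconditionally, whereas the sketch guards with hasattr so bare modules pass too):

from collections import OrderedDict

import torch
import torch.nn as nn

def strip_module_prefix(state_dict):
    clean = OrderedDict()
    for k, v in state_dict.items():
        # Checkpoints saved from (Distributed)DataParallel wrappers prefix
        # every key with 'module.'; drop it so the bare module can load them.
        clean[k.replace('module.', '')] = v
    return clean

def load_network(load_path, network: nn.Module, strict=True):
    if hasattr(network, 'module'):  # unwrap DataParallel/DDP if present
        network = network.module
    load_net = torch.load(load_path)
    network.load_state_dict(strip_module_prefix(load_net), strict=strict)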

View File

@@ -1,9 +1,9 @@
 # Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
 class Evaluator:
     def __init__(self, model, opt_eval, env):
-        self.model = model
+        self.model = model.module if hasattr(model, 'module') else model
         self.opt = opt_eval
         self.env = env

     def perform_eval(self):
-        return {}
+        return {}
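
Unwrapping .module in the Evaluator base class means every evaluator sees the bare nn.Module, whether or not the trainer handed it a (Distributed)DataParallel wrapper. A tiny illustrative check (the names below are illustrative, not from the repo):

import torch.nn as nn
from torch.nn.parallel import DataParallel

net = nn.Linear(4, 4)
wrapped = DataParallel(net)  # constructing the wrapper does not require a GPU

unwrapped = wrapped.module if hasattr(wrapped, 'module') else wrapped
assert unwrapped is net      # evaluators now always hold the plain module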

View File

@@ -25,15 +25,18 @@ class FlowGaussianNll(evaluator.Evaluator):
     def perform_eval(self):
         total_zs = 0
         z_loss = 0
+        self.model.eval()
         with torch.no_grad():
             print("Evaluating FlowGaussianNll..")
             for batch in tqdm(self.dataloader):
-                z, _, _ = self.model(gt=batch['GT'],
-                                     lr=batch['LQ'],
+                dev = self.env['device']
+                z, _, _ = self.model(gt=batch['GT'].to(dev),
+                                     lr=batch['LQ'].to(dev),
                                      epses=[],
                                      reverse=False,
                                      add_gt_noise=False)
                 for z_ in z:
                     z_loss += GaussianDiag.logp(None, None, z_).mean()
                     total_zs += 1
+        self.model.train()
         return {"gaussian_diff": z_loss / total_zs}

View File

@@ -255,13 +255,14 @@ class Trainer:
                     self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
                     self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)

-        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
+        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0:
             eval_dict = {}
             for eval in self.evaluators:
                 eval_dict.update(eval.perform_eval())
-            print("Evaluator results: ", eval_dict)
-            for ek, ev in eval_dict.items():
-                self.tb_logger.add_scalar(ek, ev, self.current_step)
+            if self.rank <= 0:
+                print("Evaluator results: ", eval_dict)
+                for ek, ev in eval_dict.items():
+                    self.tb_logger.add_scalar(ek, ev, self.current_step)

     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
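
The self.rank <= 0 gate follows this codebase's convention that rank is 0 for the primary distributed process and negative (typically -1) for single-process runs, so evaluator output and tensorboard writes happen exactly once. A sketch of that pattern, assuming torch.distributed is initialized only under a distributed launcher (current_rank is a hypothetical helper, not the Trainer attribute):

import torch.distributed as dist

def current_rank() -> int:
    # A negative rank signals "not distributed"; rank 0 is the designated logger.
    return dist.get_rank() if dist.is_available() and dist.is_initialized() else -1

if current_rank() <= 0:
    print("evaluator results are reported from a single process only")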