Go back to apex DDP, fix distributed bugs
parent 7a81d4e2f4
commit 8a83b1c716
@@ -4,7 +4,7 @@ import os
 import torch
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
-from torch.nn.parallel.distributed import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
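Both modules export a class named `DistributedDataParallel`, so only the import line changes. A hedged sketch of a guarded import, in case apex is not installed; the torch fallback is an assumption for illustration, not this commit's behavior:

```python
# Sketch: prefer apex's DDP, fall back to torch's native DDP if apex is absent.
# The fallback path is illustrative only -- this commit uses apex unconditionally.
try:
    from apex.parallel import DistributedDataParallel  # apex fused DDP
except ImportError:
    from torch.nn.parallel.distributed import DistributedDataParallel  # torch native DDP
```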
@@ -106,9 +106,7 @@ class ExtensibleTrainer(BaseModel):
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet,
-                                               device_ids=[torch.cuda.current_device()],
-                                               find_unused_parameters=False)
+                dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
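The apex wrapper takes neither `device_ids` nor `find_unused_parameters`; it reads the default process group and expects the module to already sit on this rank's GPU. A minimal sketch of that wrapping path, assuming `init_process_group` has already been called (the helper name `wrap_distributed` is illustrative, not from this repo):

```python
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel

def wrap_distributed(anet: torch.nn.Module) -> torch.nn.Module:
    # apex's DDP discovers the device from the module itself, so the network
    # must already live on this rank's GPU.
    assert dist.is_initialized(), "call torch.distributed.init_process_group first"
    anet = anet.cuda(torch.cuda.current_device())
    # delay_allreduce=True buffers all gradient reductions until backward()
    # completes, which tolerates parameters that receive no gradient -- the
    # problem find_unused_parameters addressed in torch's native DDP.
    return DistributedDataParallel(anet, delay_allreduce=True)
```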
@@ -97,8 +97,8 @@ class BaseModel():
         return save_path
 
     def load_network(self, load_path, network, strict=True):
-        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
-            network = network.module
+        #if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
+        network = network.module
         load_net = torch.load(load_path)
 
         # Support loading torch.save()s for whole models as well as just state_dicts.
@@ -109,21 +109,7 @@ class BaseModel():
         load_net_clean = OrderedDict()  # remove unnecessary 'module.'
         for k, v in load_net.items():
             if k.startswith('module.'):
-                load_net_clean[k[7:]] = v
-            if k.startswith('generator'):  # Hack to fix ESRGAN pretrained model.
-                load_net_clean[k[10:]] = v
-            if 'RRDB_trunk' in k or is_srflow:  # Hacks to fix SRFlow imports, which uses some strange RDB names.
-                is_srflow = True
-                fixed_key = k.replace('RRDB_trunk', 'body')
-                if '.RDB' in fixed_key:
-                    fixed_key = fixed_key.replace('.RDB', '.rdb')
-                elif '.upconv' in fixed_key:
-                    fixed_key = fixed_key.replace('.upconv', '.conv_up')
-                elif '.trunk_conv' in fixed_key:
-                    fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
-                elif '.HRconv' in fixed_key:
-                    fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
-                load_net_clean[fixed_key] = v
+                load_net_clean[k.replace('module.', '')] = v
             else:
                 load_net_clean[k] = v
         network.load_state_dict(load_net_clean, strict=strict)
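With the wrapper now unconditionally unwrapped above, the only key surgery left is stripping the `'module.'` prefix that (Distributed)DataParallel adds when saving; the ESRGAN/SRFlow renaming hacks are dropped. A standalone sketch of what the simplified loader does (the function name is an assumption, not from this repo):

```python
from collections import OrderedDict
import torch

def load_cleaned_state_dict(network: torch.nn.Module, load_path: str, strict: bool = True):
    # Checkpoints saved from a (Distributed)DataParallel wrapper prefix every
    # key with 'module.', which a bare module will not recognize.
    load_net = torch.load(load_path)
    load_net_clean = OrderedDict()
    for k, v in load_net.items():
        if k.startswith('module.'):
            # Mirrors the commit's k.replace('module.', ''), which strips the
            # prefix (and any other occurrence of the substring).
            load_net_clean[k.replace('module.', '')] = v
        else:
            load_net_clean[k] = v
    network.load_state_dict(load_net_clean, strict=strict)
```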
@@ -1,7 +1,7 @@
 # Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
 class Evaluator:
     def __init__(self, model, opt_eval, env):
-        self.model = model
+        self.model = model.module if hasattr(model, 'module') else model
         self.opt = opt_eval
         self.env = env
 
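`hasattr(model, 'module')` covers all three wrappers this codebase can hand an evaluator (DataParallel, torch DDP, apex DDP), each of which exposes the wrapped network as `.module`. Sketched as a standalone helper, where the name `unwrap_model` is an assumption:

```python
import torch

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # DataParallel, torch DDP, and apex DDP all store the original
    # network as their .module attribute; a bare module has no such attr.
    return model.module if hasattr(model, 'module') else model
```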
@@ -25,15 +25,18 @@ class FlowGaussianNll(evaluator.Evaluator):
     def perform_eval(self):
         total_zs = 0
         z_loss = 0
         self.model.eval()
         with torch.no_grad():
             print("Evaluating FlowGaussianNll..")
             for batch in tqdm(self.dataloader):
-                z, _, _ = self.model(gt=batch['GT'],
-                                     lr=batch['LQ'],
+                dev = self.env['device']
+                z, _, _ = self.model(gt=batch['GT'].to(dev),
+                                     lr=batch['LQ'].to(dev),
                                      epses=[],
                                      reverse=False,
                                      add_gt_noise=False)
                 for z_ in z:
                     z_loss += GaussianDiag.logp(None, None, z_).mean()
                     total_zs += 1
         self.model.train()
         return {"gaussian_diff": z_loss / total_zs}
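Under DDP each rank must feed its model tensors that live on its own GPU, while the dataloader yields CPU tensors; hence the `.to(dev)` calls. A sketch generalizing the same fix to a whole batch dict (the helper name `batch_to_device` is illustrative):

```python
import torch

def batch_to_device(batch: dict, dev: torch.device) -> dict:
    # Move every tensor in a dataloader batch onto this rank's device;
    # non-tensor entries (paths, metadata) pass through unchanged.
    return {k: v.to(dev) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()}

# Usage inside an eval loop, mirroring the hunk above:
#   batch = batch_to_device(batch, self.env['device'])
```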
@@ -255,13 +255,14 @@ class Trainer:
                 self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
                 self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
 
-        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
+        if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0:
             eval_dict = {}
             for eval in self.evaluators:
                 eval_dict.update(eval.perform_eval())
-            print("Evaluator results: ", eval_dict)
-            for ek, ev in eval_dict.items():
-                self.tb_logger.add_scalar(ek, ev, self.current_step)
+            if self.rank <= 0:
+                print("Evaluator results: ", eval_dict)
+                for ek, ev in eval_dict.items():
+                    self.tb_logger.add_scalar(ek, ev, self.current_step)
 
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
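Every rank executes the training loop, so console prints and TensorBoard writes must be gated to a single process; rank is -1 in non-distributed runs, which is why the guard is `<= 0` rather than `== 0`. A minimal sketch of the gating pattern (the helper name is illustrative):

```python
import torch.distributed as dist

def get_rank() -> int:
    # -1 for single-process runs; the process rank once init_process_group ran.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return -1

# Gate side effects so only one process logs or evaluates:
if get_rank() <= 0:
    print("only rank 0 (or a non-distributed run) reaches this branch")
```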