Go back to apex DDP, fix distributed bugs
This commit is contained in:
parent
7a81d4e2f4
commit
8a83b1c716
|
@ -4,7 +4,7 @@ import os
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parallel import DataParallel
|
from torch.nn.parallel import DataParallel
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from apex.parallel import DistributedDataParallel
|
||||||
|
|
||||||
import models.lr_scheduler as lr_scheduler
|
import models.lr_scheduler as lr_scheduler
|
||||||
import models.networks as networks
|
import models.networks as networks
|
||||||
|
@ -106,9 +106,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||||
for anet in all_networks:
|
for anet in all_networks:
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
dnet = DistributedDataParallel(anet,
|
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
||||||
device_ids=[torch.cuda.current_device()],
|
|
||||||
find_unused_parameters=False)
|
|
||||||
else:
|
else:
|
||||||
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
|
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
|
|
|
@ -97,8 +97,8 @@ class BaseModel():
|
||||||
return save_path
|
return save_path
|
||||||
|
|
||||||
def load_network(self, load_path, network, strict=True):
|
def load_network(self, load_path, network, strict=True):
|
||||||
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||||
network = network.module
|
network = network.module
|
||||||
load_net = torch.load(load_path)
|
load_net = torch.load(load_path)
|
||||||
|
|
||||||
# Support loading torch.save()s for whole models as well as just state_dicts.
|
# Support loading torch.save()s for whole models as well as just state_dicts.
|
||||||
|
@ -109,21 +109,7 @@ class BaseModel():
|
||||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||||
for k, v in load_net.items():
|
for k, v in load_net.items():
|
||||||
if k.startswith('module.'):
|
if k.startswith('module.'):
|
||||||
load_net_clean[k[7:]] = v
|
load_net_clean[k.replace('module.', '')] = v
|
||||||
if k.startswith('generator'): # Hack to fix ESRGAN pretrained model.
|
|
||||||
load_net_clean[k[10:]] = v
|
|
||||||
if 'RRDB_trunk' in k or is_srflow: # Hacks to fix SRFlow imports, which uses some strange RDB names.
|
|
||||||
is_srflow = True
|
|
||||||
fixed_key = k.replace('RRDB_trunk', 'body')
|
|
||||||
if '.RDB' in fixed_key:
|
|
||||||
fixed_key = fixed_key.replace('.RDB', '.rdb')
|
|
||||||
elif '.upconv' in fixed_key:
|
|
||||||
fixed_key = fixed_key.replace('.upconv', '.conv_up')
|
|
||||||
elif '.trunk_conv' in fixed_key:
|
|
||||||
fixed_key = fixed_key.replace('.trunk_conv', '.conv_body')
|
|
||||||
elif '.HRconv' in fixed_key:
|
|
||||||
fixed_key = fixed_key.replace('.HRconv', '.conv_hr')
|
|
||||||
load_net_clean[fixed_key] = v
|
|
||||||
else:
|
else:
|
||||||
load_net_clean[k] = v
|
load_net_clean[k] = v
|
||||||
network.load_state_dict(load_net_clean, strict=strict)
|
network.load_state_dict(load_net_clean, strict=strict)
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
|
# Base class for an evaluator, which is responsible for feeding test data through a model and evaluating the response.
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
def __init__(self, model, opt_eval, env):
|
def __init__(self, model, opt_eval, env):
|
||||||
self.model = model
|
self.model = model.module if hasattr(model, 'module') else model
|
||||||
self.opt = opt_eval
|
self.opt = opt_eval
|
||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
def perform_eval(self):
|
def perform_eval(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
|
@ -25,15 +25,18 @@ class FlowGaussianNll(evaluator.Evaluator):
|
||||||
def perform_eval(self):
|
def perform_eval(self):
|
||||||
total_zs = 0
|
total_zs = 0
|
||||||
z_loss = 0
|
z_loss = 0
|
||||||
|
self.model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
print("Evaluating FlowGaussianNll..")
|
print("Evaluating FlowGaussianNll..")
|
||||||
for batch in tqdm(self.dataloader):
|
for batch in tqdm(self.dataloader):
|
||||||
z, _, _ = self.model(gt=batch['GT'],
|
dev = self.env['device']
|
||||||
lr=batch['LQ'],
|
z, _, _ = self.model(gt=batch['GT'].to(dev),
|
||||||
|
lr=batch['LQ'].to(dev),
|
||||||
epses=[],
|
epses=[],
|
||||||
reverse=False,
|
reverse=False,
|
||||||
add_gt_noise=False)
|
add_gt_noise=False)
|
||||||
for z_ in z:
|
for z_ in z:
|
||||||
z_loss += GaussianDiag.logp(None, None, z_).mean()
|
z_loss += GaussianDiag.logp(None, None, z_).mean()
|
||||||
total_zs += 1
|
total_zs += 1
|
||||||
|
self.model.train()
|
||||||
return {"gaussian_diff": z_loss / total_zs}
|
return {"gaussian_diff": z_loss / total_zs}
|
||||||
|
|
|
@ -255,13 +255,14 @@ class Trainer:
|
||||||
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
|
self.tb_logger.add_scalar('val_psnr', avg_psnr, self.current_step)
|
||||||
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
|
self.tb_logger.add_scalar('val_fea', avg_fea_loss, self.current_step)
|
||||||
|
|
||||||
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0:
|
if len(self.evaluators) != 0 and self.current_step % opt['train']['val_freq'] == 0 and self.rank <= 0:
|
||||||
eval_dict = {}
|
eval_dict = {}
|
||||||
for eval in self.evaluators:
|
for eval in self.evaluators:
|
||||||
eval_dict.update(eval.perform_eval())
|
eval_dict.update(eval.perform_eval())
|
||||||
print("Evaluator results: ", eval_dict)
|
if self.rank <= 0:
|
||||||
for ek, ev in eval_dict.items():
|
print("Evaluator results: ", eval_dict)
|
||||||
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
for ek, ev in eval_dict.items():
|
||||||
|
self.tb_logger.add_scalar(ek, ev, self.current_step)
|
||||||
|
|
||||||
def do_training(self):
|
def do_training(self):
|
||||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user