More adjustments to support distributed training with teco & on multi_modal_train

James Betker 2020-10-27 20:58:03 -06:00
parent 5d09027ee2
commit da53090ce6
5 changed files with 88 additions and 76 deletions


@@ -30,7 +30,9 @@ class ExtensibleTrainer(BaseModel):
         self.env = {'device': self.device,
                     'rank': self.rank,
                     'opt': opt,
-                    'step': 0}
+                    'step': 0,
+                    'dist': opt['dist']
+                    }
         if opt['path']['models'] is not None:
             self.env['base_path'] = os.path.join(opt['path']['models'])

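Note on the new 'dist' key: ExtensibleTrainer's env dict is shared with its steps and their injectors and losses, so this change lets any of them ask whether the run is distributed without re-querying torch.distributed. A minimal sketch of a hypothetical consumer — nothing below is part of this commit; only the 'dist' and 'rank' keys are taken from the diff above:

# Hypothetical consumer of the env dict built by ExtensibleTrainer.
class RankZeroLogger:
    def __init__(self, env):
        self.env = env

    def log(self, msg):
        # In a distributed run, only rank 0 prints to avoid duplicated output;
        # rank is -1 when distributed training is disabled.
        if not self.env['dist'] or self.env['rank'] <= 0:
            print(msg)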

@@ -94,12 +94,12 @@ class SPSRNet(nn.Module):
         n_upscale = int(math.log(upscale, 2))
-        self.scale=n_upscale
+        self.scale=upscale
         if upscale == 3:
             n_upscale = 1
         fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False)
-        self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=7, norm=False, activation=False)
+        self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=7, norm=False, activation=False)
         self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
         rb_blocks = [RRDB(nf) for _ in range(nb)]
@@ -118,7 +118,7 @@ class SPSRNet(nn.Module):
                                        *upsampler, self.HR_conv0_new)
         self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False)
-        self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=3, norm=False, activation=False)
+        self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=3, norm=False, activation=False)
         self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
         self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
@@ -657,4 +657,4 @@ class SwitchedSpsr(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-        return val
+        return val

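Why the stride change above is a fix rather than a refactor: n_upscale counts the number of 2x upsampling stages, while the reference branch presumably needs to downsample its input by the full scale factor to line up with the features computed from the LR input. A quick worked illustration; the value of upscale is assumed for the example:

import math

upscale = 4
n_upscale = int(math.log(upscale, 2))   # == 2: two 2x upsampling stages
# ref_conv previously used stride=n_upscale (2), so for a 4x model the reference
# was only downsampled by 2 and no longer matched the LR-resolution features it
# is joined with; stride=upscale (4) restores the match. For 2x models the two
# values coincide, which is presumably how the bug went unnoticed.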

@@ -126,49 +126,48 @@ class ConfigurableStep(Module):
         self.env['current_step_optimizers'] = self.optimizers
         self.env['training'] = train
 
-        with self.get_network_for_name(self.get_networks_trained()[0]).join():
-            # Inject in any extra dependencies.
-            for inj in self.injectors:
-                # Don't do injections tagged with eval unless we are not in train mode.
-                if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
-                    continue
-                # Likewise, don't do injections tagged with train unless we are not in eval.
-                if not train and 'train' in inj.opt.keys() and inj.opt['train']:
-                    continue
-                # Don't do injections tagged with 'after' or 'before' when we are out of spec.
-                if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
-                        'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
-                    continue
-                injected = inj(local_state)
-                local_state.update(injected)
-                new_state.update(injected)
+        # Inject in any extra dependencies.
+        for inj in self.injectors:
+            # Don't do injections tagged with eval unless we are not in train mode.
+            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
+                continue
+            # Likewise, don't do injections tagged with train unless we are not in eval.
+            if not train and 'train' in inj.opt.keys() and inj.opt['train']:
+                continue
+            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
+                    'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+                continue
+            injected = inj(local_state)
+            local_state.update(injected)
+            new_state.update(injected)
 
         if train and len(self.losses) > 0:
             # Finally, compute the losses.
             total_loss = 0
             for loss_name, loss in self.losses.items():
                 # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
                 # be very disruptive to a generator.
                 if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
                     continue
                 l = loss(self.training_net, local_state)
                 total_loss += l * self.weights[loss_name]
                 # Record metrics.
                 if isinstance(l, torch.Tensor):
                     self.loss_accumulator.add_loss(loss_name, l)
                 for n, v in loss.extra_metrics():
                     self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
                 loss.clear_metrics()
             # In some cases, the loss could not be set (e.g. all losses have 'after')
             if isinstance(total_loss, torch.Tensor):
                 self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
                 # Scale the loss down by the accumulation factor.
                 total_loss = total_loss / self.env['mega_batch_factor']
                 # Get dem grads!
                 self.scaler.scale(total_loss).backward()
                 self.grads_generated = True
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.

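For context on the removed wrapper: get_network_for_name() returns the (possibly DDP-wrapped) training network, and .join() is torch.nn.parallel.DistributedDataParallel.join(), a context manager added in PyTorch 1.7 for training with uneven numbers of inputs across ranks. A generic sketch of its intended use, with placeholder model/loader/optimizer names rather than this repo's classes:

import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

def train_with_uneven_inputs(model, loader, optimizer, rank):
    ddp_model = DDP(model, device_ids=[rank])
    # join() lets ranks that exhaust their data early keep shadowing the
    # collective calls of ranks that still have batches, avoiding a hang.
    with ddp_model.join():
        for x, y in loader:
            optimizer.zero_grad()
            loss = F.mse_loss(ddp_model(x), y)
            loss.backward()
            optimizer.step()

Wrapping only the injector pass in join() while the backward ran outside it, as the removed lines appear to have done, is an unusual pattern, which may be why it was dropped here.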

@@ -15,15 +15,29 @@ import yaml
 import train
 import utils.options as option
 from utils.util import OrderedYaml
+import torch
 
 def main(master_opt, launcher):
     trainers = []
     all_networks = {}
     shared_networks = []
+    if launcher != 'none':
+        train.init_dist('nccl')
     for i, sub_opt in enumerate(master_opt['trainer_options']):
         sub_opt_parsed = option.parse(sub_opt, is_train=True)
         trainer = train.Trainer()
+
+        #### distributed training settings
+        if launcher == 'none':  # disabled distributed training
+            sub_opt_parsed['dist'] = False
+            trainer.rank = -1
+            print('Disabled distributed training.')
+        else:
+            sub_opt_parsed['dist'] = True
+            trainer.world_size = torch.distributed.get_world_size()
+            trainer.rank = torch.distributed.get_rank()
+
         trainer.init(sub_opt_parsed, launcher, all_networks)
         train_gen = trainer.create_training_generator(i)
         model = next(train_gen)
@@ -44,6 +58,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     Loader, Dumper = OrderedYaml()

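Usage note (hedged): with --launcher pytorch, multi_modal_train.py is expected to be started by PyTorch's launcher, e.g. python -m torch.distributed.launch --nproc_per_node=<gpus> multi_modal_train.py -opt <config> --launcher pytorch, where the config path is illustrative. The launcher starts one process per GPU, exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, and appends a --local_rank argument to each process, which is all the new argparse entry has to absorb. A small standalone sketch of what one launched process sees, with example values:

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none')
parser.add_argument('--local_rank', type=int, default=0)  # injected by torch.distributed.launch
args = parser.parse_args(['--launcher', 'pytorch', '--local_rank', '1'])

print(args.local_rank)                  # which GPU this process should bind to
print(os.environ.get('RANK'))           # global rank set by the launcher, e.g. '1'
print(os.environ.get('WORLD_SIZE'))     # total process count, e.g. '2'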

@@ -13,43 +13,26 @@ from data import create_dataloader, create_dataset
 from models.ExtensibleTrainer import ExtensibleTrainer
 from time import time
 
-class Trainer:
-    def init_dist(self, backend, **kwargs):
-        # These packages have globals that screw with Windows, so only import them if needed.
-        import torch.distributed as dist
-        import torch.multiprocessing as mp
-        """initialization for distributed training"""
-        if mp.get_start_method(allow_none=True) != 'spawn':
-            mp.set_start_method('spawn')
-        self.rank = int(os.environ['RANK'])
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(self.rank % num_gpus)
-        dist.init_process_group(backend=backend, **kwargs)
+def init_dist(backend, **kwargs):
+    # These packages have globals that screw with Windows, so only import them if needed.
+    import torch.distributed as dist
+    import torch.multiprocessing as mp
+    """initialization for distributed training"""
+    if mp.get_start_method(allow_none=True) != 'spawn':
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
 
+class Trainer:
     def init(self, opt, launcher, all_networks={}):
         self._profile = False
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True
-        #### distributed training settings
-        if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
-            gpu = input(
-                'I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
-                'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
-            if gpu:
-                opt['gpu_ids'] = [int(gpu)]
-        if launcher == 'none':  # disabled distributed training
-            opt['dist'] = False
-            self.rank = -1
-            print('Disabled distributed training.')
-        else:
-            opt['dist'] = True
-            self.init_dist('nccl')
-            world_size = torch.distributed.get_world_size()
-            self.rank = torch.distributed.get_rank()
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
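Because the refactored init_dist passes no init_method, dist.init_process_group() falls back to the default env:// rendezvous, so it only works once MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are present in the environment. A minimal single-process stand-in for what the launcher normally provides — all values are placeholders, and gloo is substituted so the sketch also runs without a GPU:

import os
import torch
import torch.distributed as dist

# Placeholder rendezvous settings; torch.distributed.launch exports these in real runs.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')

rank = int(os.environ['RANK'])
if torch.cuda.is_available():
    torch.cuda.set_device(rank % torch.cuda.device_count())
dist.init_process_group(backend='gloo')  # the repo uses 'nccl', which requires CUDA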
@@ -117,7 +100,7 @@ class Trainer:
             total_iters = int(opt['train']['niter'])
             self.total_epochs = int(math.ceil(total_iters / train_size))
             if opt['dist']:
-                self.train_sampler = DistIterSampler(self.train_set, world_size, self.rank, dataset_ratio)
+                self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
                 self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
             else:
                 self.train_sampler = None
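self.world_size now has to be set on the Trainer because the distributed sampler needs to know how many ranks split the dataset. For intuition, the same idea using the stock torch.utils.data.DistributedSampler as a stand-in for this repo's DistIterSampler (which additionally takes a dataset_ratio argument):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100).float())
# Each rank iterates a disjoint 1/num_replicas slice of the dataset per epoch.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # keeps the shuffle consistent across ranks each epoch
    for (batch,) in loader:
        pass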
@@ -288,5 +271,18 @@ if __name__ == '__main__':
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     trainer = Trainer()
+
+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        trainer.rank = -1
+        print('Disabled distributed training.')
+    else:
+        opt['dist'] = True
+        init_dist('nccl')
+        trainer.world_size = torch.distributed.get_world_size()
+        trainer.rank = torch.distributed.get_rank()
+
     trainer.init(opt, args.launcher)
     trainer.do_training()