Distributed "fixes"

James Betker 2022-03-04 12:46:41 -07:00
parent 3ff878ae85
commit ce6dfdf255
3 changed files with 8 additions and 10 deletions

View File

@@ -23,14 +23,10 @@ from utils.util import opt_get, map_cuda_to_correct_device
 def init_dist(backend, **kwargs):
     # These packages have globals that screw with Windows, so only import them if needed.
     import torch.distributed as dist
-    import torch.multiprocessing as mp
-    """initialization for distributed training"""
-    if mp.get_start_method(allow_none=True) != 'spawn':
-        mp.set_start_method('spawn')
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
+    rank = int(os.environ['LOCAL_RANK'])
+    assert rank < torch.cuda.device_count()
+    torch.cuda.set_device(rank)
     dist.init_process_group(backend=backend, **kwargs)
 
 
 class Trainer:
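
In the new body, each worker binds to the GPU named by LOCAL_RANK, which torchrun exports per process, instead of deriving a device from the global RANK modulo the GPU count and forcing the 'spawn' start method. A minimal self-contained sketch of the same pattern under a torchrun launch (illustrative, not code from this repo):

    # Launched as: torchrun --nproc_per_node=<num_gpus> train.py ...
    # torchrun exports LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and
    # MASTER_PORT into each worker's environment.
    import os
    import torch
    import torch.distributed as dist

    def init_dist(backend='nccl'):
        local_rank = int(os.environ['LOCAL_RANK'])   # GPU index on this node
        assert local_rank < torch.cuda.device_count()
        torch.cuda.set_device(local_rank)            # one process <-> one GPU
        # The default env:// rendezvous reads RANK/WORLD_SIZE/MASTER_* itself.
        dist.init_process_group(backend=backend)
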
@@ -321,7 +317,6 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     if args.launcher != 'none':
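
Dropping --local_rank is the other half of that migration: the legacy `python -m torch.distributed.launch` injected the rank as a command-line flag into every worker, whereas torchrun passes it only through the LOCAL_RANK environment variable. A hedged sketch of launcher-agnostic rank discovery (not from this commit):

    import os

    # torch.distributed.launch (legacy) appended --local_rank=N to sys.argv;
    # torchrun sets LOCAL_RANK=N in the environment instead, so argparse no
    # longer needs to know about the flag at all.
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
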
@@ -344,6 +339,7 @@ if __name__ == '__main__':
         init_dist('nccl')
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
+        torch.cuda.set_device(torch.distributed.get_rank())
 
     trainer.init(args.opt, opt, args.launcher)
     trainer.do_training()
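
One caveat worth noting: torch.distributed.get_rank() is the global rank, which coincides with LOCAL_RANK only in single-node runs, so this added set_device call re-binds the same GPU init_dist already selected under that assumption. A sketch of the distinction (illustrative):

    # Assumes init_process_group has already run.
    # Single node, 4 GPUs: global ranks 0..3 and local ranks 0..3 coincide.
    # Two nodes, 4 GPUs each: on node 1 the global ranks are 4..7, but the
    # valid CUDA ordinals are still 0..3, so only LOCAL_RANK is safe there.
    import os
    import torch.distributed as dist

    global_rank = dist.get_rank()                 # 0 .. world_size - 1
    local_rank = int(os.environ['LOCAL_RANK'])    # 0 .. nproc_per_node - 1
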

View File

@@ -139,7 +139,9 @@ class ExtensibleTrainer(BaseModel):
                 from torch.nn.parallel.distributed import DistributedDataParallel
                 # Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
                 # used and in a few other cases. But you can try it if you really want.
-                dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
+                dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
+                                               output_device=torch.cuda.current_device(),
+                                               find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
                 # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
                 # which does not work with this trainer (as stated above). However, if the graph is not subject
                 # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
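
For single-device modules, DistributedDataParallel already defaults output_device to device_ids[0], so spelling it out here mostly documents intent. A minimal standalone version of the wrapper (the Linear module stands in for anet; assumptions, not repo code):

    import torch
    from torch.nn.parallel.distributed import DistributedDataParallel

    # Assumes init_dist() has run, so current_device() is this process's GPU.
    device = torch.cuda.current_device()
    net = torch.nn.Linear(64, 64).to(device)   # stand-in for anet
    dnet = DistributedDataParallel(
        net,
        device_ids=[device],           # the one GPU this replica computes on
        output_device=device,          # where forward() outputs land (the default)
        find_unused_parameters=False,  # True breaks checkpointing here, per the comment above
    )
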

View File

@@ -16,7 +16,7 @@ class BaseModel():
             self.rank = torch.distributed.get_rank()
         else:
             self.rank = -1  # non dist training
-        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
+        self.device = torch.cuda.current_device() if opt['gpu_ids'] else torch.device('cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
         self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
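
Two things change on the device line: torch.cuda.current_device() returns an integer ordinal rather than a torch.device, which PyTorch accepts anywhere a device is expected, and the truthiness test means an empty gpu_ids list now falls back to CPU where the old `is not None` check would have selected CUDA. A small illustration of the ordinal form (hedged, not repo code):

    import torch

    torch.cuda.set_device(0)
    device = torch.cuda.current_device()   # -> 0, a plain int
    x = torch.zeros(4).to(device)          # int ordinal targets cuda:0
    assert x.device == torch.device('cuda', device)
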