Distributed "fixes"

This commit is contained in:
James Betker 2022-03-04 12:46:41 -07:00
parent 3ff878ae85
commit ce6dfdf255
3 changed files with 8 additions and 10 deletions

View File

@@ -23,14 +23,10 @@ from utils.util import opt_get, map_cuda_to_correct_device
def init_dist(backend, **kwargs):
# These packages have globals that screw with Windows, so only import them if needed.
import torch.distributed as dist
import torch.multiprocessing as mp
"""initialization for distributed training"""
if mp.get_start_method(allow_none=True) != 'spawn':
mp.set_start_method('spawn')
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
rank = int(os.environ['LOCAL_RANK'])
assert rank < torch.cuda.device_count()
torch.cuda.set_device(rank)
dist.init_process_group(backend=backend, **kwargs)
class Trainer:
@@ -321,7 +317,6 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)
if args.launcher != 'none':
@@ -344,6 +339,7 @@ if __name__ == '__main__':
init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(args.opt, opt, args.launcher)
trainer.do_training()

View File

@@ -139,7 +139,9 @@ class ExtensibleTrainer(BaseModel):
from torch.nn.parallel.distributed import DistributedDataParallel
# Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
# used and in a few other cases. But you can try it if you really want.
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
output_device=torch.cuda.current_device(),
find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
# DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
# which does not work with this trainer (as stated above). However, if the graph is not subject
# to control flow alterations, you can set this option to allow gradient checkpointing. Beware that

View File

@@ -16,7 +16,7 @@ class BaseModel():
self.rank = torch.distributed.get_rank()
else:
self.rank = -1 # non dist training
self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
self.device = torch.cuda.current_device() if opt['gpu_ids'] else torch.device('cpu')
self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
self.is_train = opt['is_train']
self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)