Distributed "fixes"
commit ce6dfdf255
parent 3ff878ae85
@@ -23,14 +23,10 @@ from utils.util import opt_get, map_cuda_to_correct_device
 def init_dist(backend, **kwargs):
     # These packages have globals that screw with Windows, so only import them if needed.
     import torch.distributed as dist
-    import torch.multiprocessing as mp
-
     """initialization for distributed training"""
-    if mp.get_start_method(allow_none=True) != 'spawn':
-        mp.set_start_method('spawn')
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
+    rank = int(os.environ['LOCAL_RANK'])
+    assert rank < torch.cuda.device_count()
+    torch.cuda.set_device(rank)
     dist.init_process_group(backend=backend, **kwargs)

 class Trainer:
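The rewritten init relies on the launcher exporting LOCAL_RANK for every worker, which torchrun (and torch.distributed.launch with --use_env) does. A minimal standalone sketch of the same pattern; the file name and the toy all_reduce are illustrative only, not code from this repository:

# dist_init_sketch.py -- launch with: torchrun --nproc_per_node=<num_gpus> dist_init_sketch.py
import os

import torch
import torch.distributed as dist


def init_dist(backend='nccl', **kwargs):
    # torchrun sets LOCAL_RANK per worker; pin this process to that GPU before
    # joining the process group so NCCL picks the right device.
    rank = int(os.environ['LOCAL_RANK'])
    assert rank < torch.cuda.device_count()
    torch.cuda.set_device(rank)
    dist.init_process_group(backend=backend, **kwargs)


if __name__ == '__main__':
    init_dist('nccl')
    # Each worker contributes its rank; after all_reduce every worker holds the sum.
    t = torch.tensor([float(dist.get_rank())], device='cuda')
    dist.all_reduce(t)
    if dist.get_rank() == 0:
        print('sum of ranks:', t.item())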
@@ -321,7 +317,6 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     if args.launcher != 'none':
@@ -344,6 +339,7 @@ if __name__ == '__main__':
         init_dist('nccl')
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
+        torch.cuda.set_device(torch.distributed.get_rank())

     trainer.init(args.opt, opt, args.launcher)
     trainer.do_training()
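trainer.rank and trainer.world_size are the usual hooks for rank-0-only logging/checkpointing and for sharding data across processes. A generic illustration of both patterns, assuming the process group is already initialized; DistributedSampler here is a stock PyTorch utility, not code from this trainer:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_loader(batch_size=8):
    # DistributedSampler shards the dataset using the same rank/world_size the
    # trainer records after init_dist(); fall back to plain shuffling otherwise.
    dataset = TensorDataset(torch.arange(128, dtype=torch.float32).unsqueeze(1))
    sampler = DistributedSampler(dataset) if dist.is_initialized() else None
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      shuffle=(sampler is None))


def log(msg, rank):
    # rank <= 0 covers both single-process training (-1) and distributed rank 0.
    if rank <= 0:
        print(msg)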
@@ -139,7 +139,9 @@ class ExtensibleTrainer(BaseModel):
                from torch.nn.parallel.distributed import DistributedDataParallel
                # Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
                # used and in a few other cases. But you can try it if you really want.
-               dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
+               dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
+                                              output_device=torch.cuda.current_device(),
+                                              find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
                # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
                # which does not work with this trainer (as stated above). However, if the graph is not subject
                # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
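The changed call only adds output_device, keeping outputs on the same GPU this rank was pinned to. A generic sketch of the wrapping pattern with a toy module; the helper name is made up and it assumes init_dist() has already run:

import torch
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel


def wrap_ddp(net: nn.Module, find_unused: bool = False) -> DistributedDataParallel:
    dev = torch.cuda.current_device()
    net = net.to(dev)
    # device_ids/output_device name the single GPU this rank owns;
    # find_unused_parameters stays False unless parts of the graph are skipped.
    return DistributedDataParallel(net, device_ids=[dev], output_device=dev,
                                   find_unused_parameters=find_unused)


# Usage after init_dist():
#   dnet = wrap_ddp(nn.Linear(16, 4))
#   dnet(torch.randn(8, 16, device='cuda')).sum().backward()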
@@ -16,7 +16,7 @@ class BaseModel():
             self.rank = torch.distributed.get_rank()
         else:
             self.rank = -1  # non dist training
-        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
+        self.device = torch.cuda.current_device() if opt['gpu_ids'] else torch.device('cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
         self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
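Note that torch.cuda.current_device() returns a plain integer index rather than a torch.device; PyTorch accepts either anywhere a device is expected, so downstream .to(self.device) calls keep working. A quick check of that equivalence (assumes a CUDA machine):

import torch

if torch.cuda.is_available():
    idx = torch.cuda.current_device()      # e.g. 0, a plain int
    t = torch.empty(1, device=idx)         # int accepted as a CUDA device index
    assert t.device == torch.device('cuda', idx)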