Distributed "fixes"
parent 3ff878ae85
commit ce6dfdf255
@@ -23,14 +23,10 @@ from utils.util import opt_get, map_cuda_to_correct_device
 def init_dist(backend, **kwargs):
     # These packages have globals that screw with Windows, so only import them if needed.
     import torch.distributed as dist
     import torch.multiprocessing as mp
 
-    """initialization for distributed training"""
-    if mp.get_start_method(allow_none=True) != 'spawn':
-        mp.set_start_method('spawn')
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
+    rank = int(os.environ['LOCAL_RANK'])
+    assert rank < torch.cuda.device_count()
+    torch.cuda.set_device(rank)
     dist.init_process_group(backend=backend, **kwargs)
 
 
 class Trainer:
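Note on the init_dist change: RANK is the global rank across all nodes, while LOCAL_RANK is the per-node worker index that torchrun (torch.distributed.run) exports, so it is the right value to hand to torch.cuda.set_device(). A minimal, self-contained sketch of the same pattern, assuming a torchrun launch; the helper name is ours, not the repo's:

# Minimal sketch, assuming a launch along the lines of:
#   torchrun --nproc_per_node=<num_gpus> train.py -opt <options.yml> --launcher pytorch
# torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT per worker.
import os
import torch
import torch.distributed as dist

def init_dist_sketch(backend: str = 'nccl') -> None:
    local_rank = int(os.environ['LOCAL_RANK'])      # GPU index on this node
    assert local_rank < torch.cuda.device_count()
    torch.cuda.set_device(local_rank)               # pin this process to one GPU
    dist.init_process_group(backend=backend)        # env:// init reads the variables above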
@@ -321,7 +317,6 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_matcher.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     if args.launcher != 'none':
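The dropped --local_rank argument is how the legacy torch.distributed.launch wrapper passed the device index; with env-style launching it arrives as the LOCAL_RANK environment variable instead. A hedged, backward-compatible lookup (the helper name is illustrative, not part of this repo):

# Illustrative helper: accept the legacy --local_rank flag if present,
# otherwise fall back to the LOCAL_RANK environment variable set by torchrun.
import argparse
import os

def resolve_local_rank() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None)   # legacy torch.distributed.launch
    args, _ = parser.parse_known_args()
    if args.local_rank is not None:
        return args.local_rank
    return int(os.environ.get('LOCAL_RANK', 0))                   # torchrun / env:// launches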
@@ -344,6 +339,7 @@ if __name__ == '__main__':
         init_dist('nccl')
         trainer.world_size = torch.distributed.get_world_size()
         trainer.rank = torch.distributed.get_rank()
+        torch.cuda.set_device(torch.distributed.get_rank())
 
     trainer.init(args.opt, opt, args.launcher)
     trainer.do_training()
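Once world_size and rank are set as in the hunk above, the usual companion step (not shown in this commit) is sharding data per process. A short sketch using DistributedSampler; the dataset and loader are stand-ins, not code from this trainer:

# Sketch only: per-rank data sharding with DistributedSampler.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def build_sharded_loader(world_size: int, rank: int, batch_size: int = 8) -> DataLoader:
    dataset = TensorDataset(torch.arange(128, dtype=torch.float32))   # stand-in dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

Call sampler.set_epoch(epoch) at the start of each epoch so the shuffle order differs between epochs.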
@@ -139,7 +139,9 @@ class ExtensibleTrainer(BaseModel):
                 from torch.nn.parallel.distributed import DistributedDataParallel
                 # Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
                 # used and in a few other cases. But you can try it if you really want.
-                dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
+                dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
+                                               output_device=torch.cuda.current_device(),
+                                               find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
                 # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
                 # which does not work with this trainer (as stated above). However, if the graph is not subject
                 # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
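The added output_device makes explicit which GPU the wrapped module's outputs land on; for single-device modules it defaults to device_ids[0], so this pins it to the same device that torch.cuda.set_device() chose. A stripped-down restatement of that wrapping, assuming the process group is already initialized; the function name is ours:

# Sketch: wrap a model for single-GPU-per-process DDP, mirroring the call in the diff.
import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_ddp(model: torch.nn.Module, find_unused: bool = False) -> DistributedDataParallel:
    device = torch.cuda.current_device()        # the GPU chosen earlier via set_device()
    model = model.to(device)
    return DistributedDataParallel(model,
                                   device_ids=[device],
                                   output_device=device,
                                   find_unused_parameters=find_unused)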
@@ -16,7 +16,7 @@ class BaseModel():
             self.rank = torch.distributed.get_rank()
         else:
             self.rank = -1  # non dist training
-        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
+        self.device = torch.cuda.current_device() if opt['gpu_ids'] else torch.device('cpu')
         self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
         self.is_train = opt['is_train']
         self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
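The BaseModel change swaps the generic torch.device('cuda') for torch.cuda.current_device(), i.e. the explicit integer index of the GPU this process was pinned to. Both forms are accepted by Tensor.to()/Module.to(); a tiny illustrative sketch of the difference (not repo code):

# Sketch: device index vs. device object; both address the same GPU here.
import torch

if torch.cuda.is_available():
    idx = torch.cuda.current_device()                   # int, e.g. 0 after set_device(0)
    a = torch.zeros(2).to(idx)                          # index form
    b = torch.zeros(2).to(torch.device('cuda', idx))    # device-object form
    assert a.device == b.device
else:
    a = torch.zeros(2).to(torch.device('cpu'))          # CPU fallback for the sketch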