Move train imports into init_dist
This commit is contained in:
parent
e9ee67ff10
commit
ea9c6765ca
|
@ -7,8 +7,6 @@ import shutil
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
#import torch.distributed as dist
|
|
||||||
#import torch.multiprocessing as mp
|
|
||||||
from data.data_sampler import DistIterSampler
|
from data.data_sampler import DistIterSampler
|
||||||
|
|
||||||
import options.options as option
|
import options.options as option
|
||||||
|
@ -19,6 +17,10 @@ from time import time
|
||||||
|
|
||||||
|
|
||||||
def init_dist(backend='nccl', **kwargs):
|
def init_dist(backend='nccl', **kwargs):
|
||||||
|
# These packages have globals that screw with Windows, so only import them if needed.
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
"""initialization for distributed training"""
|
"""initialization for distributed training"""
|
||||||
if mp.get_start_method(allow_none=True) != 'spawn':
|
if mp.get_start_method(allow_none=True) != 'spawn':
|
||||||
mp.set_start_method('spawn')
|
mp.set_start_method('spawn')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user