diff --git a/codes/train.py b/codes/train.py
index 7157ddad..b2b87afc 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -7,8 +7,6 @@ import shutil
 from tqdm import tqdm
 
 import torch
-#import torch.distributed as dist
-#import torch.multiprocessing as mp
 from data.data_sampler import DistIterSampler
 
 import options.options as option
@@ -19,6 +17,10 @@ from time import time
 
 
 def init_dist(backend='nccl', **kwargs):
+    # These packages have globals that screw with Windows, so only import them if needed.
+    import torch.distributed as dist
+    import torch.multiprocessing as mp
+
     """initialization for distributed training"""
     if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')