Fix for train.py

This commit is contained in:
James Betker 2021-01-01 11:59:00 -07:00
parent e214e6ce33
commit 9864fe4c04
2 changed files with 3 additions and 3 deletions

View File

@@ -293,7 +293,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_xx_faces_glean.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_faces_styled_sr.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -306,7 +306,7 @@ if __name__ == '__main__':
     print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
     trainer = Trainer()

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        trainer.rank = -1
@@ -315,7 +315,7 @@ if __name__ == '__main__':
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
-       init_dist('nccl', opt)
+       init_dist('nccl')
        trainer.world_size = torch.distributed.get_world_size()
        trainer.rank = torch.distributed.get_rank()