forked from mrq/ai-voice-cloning
use torchrun instead for multigpu
This commit is contained in:
parent
5026d93ecd
commit
37cab14272
|
@ -18,12 +18,9 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, help='Rank Number')
|
||||
args = parser.parse_args()
|
||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
||||
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
with open(args.opt, 'r') as file:
|
||||
opt_config = yaml.safe_load(file)
|
||||
|
||||
|
|
2
train.sh
2
train.sh
|
@ -6,7 +6,7 @@ CONFIG=$2
|
|||
PORT=1234
|
||||
|
||||
if (( $GPUS > 1 )); then
|
||||
python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
|
||||
torchrun --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
|
||||
else
|
||||
python3 ./src/train.py -opt "$CONFIG"
|
||||
fi
|
||||
|
|
Loading…
Reference in New Issue
Block a user