forked from mrq/ai-voice-cloning
use torchrun instead for multigpu
This commit is contained in:
parent
5026d93ecd
commit
37cab14272
|
@ -18,12 +18,9 @@ if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, help='Rank Number')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
args.opt = " ".join(args.opt) # absolutely disgusting
|
||||||
|
|
||||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
|
||||||
|
|
||||||
with open(args.opt, 'r') as file:
|
with open(args.opt, 'r') as file:
|
||||||
opt_config = yaml.safe_load(file)
|
opt_config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
|
2
train.sh
2
train.sh
|
@ -6,7 +6,7 @@ CONFIG=$2
|
||||||
PORT=1234
|
PORT=1234
|
||||||
|
|
||||||
if (( $GPUS > 1 )); then
|
if (( $GPUS > 1 )); then
|
||||||
python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
|
torchrun --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
|
||||||
else
|
else
|
||||||
python3 ./src/train.py -opt "$CONFIG"
|
python3 ./src/train.py -opt "$CONFIG"
|
||||||
fi
|
fi
|
||||||
|
|
Loading…
Reference in New Issue
Block a user