1
0
Fork 0

experimental multi-gpu training (Linux only, because I can't into batch files)

remotes/1708699347150643056/master
mrq 2023-03-03 04:37:18 +07:00
parent e205322c8d
commit e859a7c01d
5 changed files with 36 additions and 18 deletions

7
.gitignore vendored

@ -1,7 +1,8 @@
# ignores user files
/tortoise-venv/
/tortoise/voices/
/models/
/venv/
/voices/*
/models/*
/training/*
/config/*
# Byte-compiled / optimized / DLL files

@ -18,9 +18,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, help='Rank Number')
args = parser.parse_args()
args.opt = " ".join(args.opt) # absolutely disgusting
os.environ['LOCAL_RANK'] = str(args.local_rank)
with open(args.opt, 'r') as file:
opt_config = yaml.safe_load(file)
@ -71,7 +74,7 @@ def train(yaml, launcher='none'):
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist('nccl')
tr.init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())

@ -477,9 +477,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
# superfluous, but it cleans up some things
class TrainingState():
def __init__(self, config_path, keep_x_past_datasets=0, start=True):
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
# parse config to get its iteration
with open(config_path, 'r') as file:
self.config = yaml.safe_load(file)
@ -530,9 +528,11 @@ class TrainingState():
if keep_x_past_datasets > 0:
self.cleanup_old(keep=keep_x_past_datasets)
if start:
self.spawn_process()
self.spawn_process(config_path=config_path, gpus=gpus)
def spawn_process(self, config_path, gpus=1):
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
def spawn_process(self):
print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
@ -745,7 +745,7 @@ class TrainingState():
if should_return:
return "".join(self.buffer)
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
global training_state
if training_state and training_state.process:
return "Training already in progress"
@ -757,7 +757,7 @@ def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets
unload_whisper()
unload_voicefixer()
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
for line in iter(training_state.process.stdout.readline, ""):
@ -785,11 +785,13 @@ def update_training_dataplot(config_path=None):
update = None
if not training_state:
training_state = TrainingState(config_path=config_path, start=False)
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
del training_state
training_state = None
else:
if config_path:
training_state = TrainingState(config_path=config_path, start=False)
if training_state.losses:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
del training_state
training_state = None
elif training_state.losses:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
return update
@ -1285,12 +1287,13 @@ def setup_args():
if not args.device_override:
set_device_name(args.device_override)
args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
try:
match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]
match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
args.listen_port = match[1] if match[1] != "" else None

@ -546,6 +546,7 @@ def setup_gradio():
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_gpu_count = gr.Number(label="GPUs", value=1)
with gr.Row():
start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop")
@ -751,6 +752,7 @@ def setup_gradio():
inputs=[
training_configs,
verbose_training,
training_gpu_count,
training_buffer_size,
training_keep_x_past_datasets,
],

@ -1,4 +1,13 @@
#!/bin/bash
source ./venv/bin/activate
python3 ./src/train.py -opt "$1"
GPUS=$1
CONFIG=$2
PORT=1234
if (( $GPUS > 1 )); then
python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
else
python3 ./src/train.py -opt "$CONFIG"
fi
deactivate