1
0

experimental multi-gpu training (Linux only, because I can't into batch files)

This commit is contained in:
mrq 2023-03-03 04:37:18 +00:00
parent e205322c8d
commit e859a7c01d
5 changed files with 36 additions and 18 deletions

7
.gitignore vendored
View File

@ -1,7 +1,8 @@
# ignores user files # ignores user files
/tortoise-venv/ /venv/
/tortoise/voices/ /voices/*
/models/ /models/*
/training/*
/config/* /config/*
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files

View File

@ -18,9 +18,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, help='Rank Number')
args = parser.parse_args() args = parser.parse_args()
args.opt = " ".join(args.opt) # absolutely disgusting args.opt = " ".join(args.opt) # absolutely disgusting
os.environ['LOCAL_RANK'] = str(args.local_rank)
with open(args.opt, 'r') as file: with open(args.opt, 'r') as file:
opt_config = yaml.safe_load(file) opt_config = yaml.safe_load(file)
@ -71,7 +74,7 @@ def train(yaml, launcher='none'):
print('Disabled distributed training.') print('Disabled distributed training.')
else: else:
opt['dist'] = True opt['dist'] = True
init_dist('nccl') tr.init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size() trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank() trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank()) torch.cuda.set_device(torch.distributed.get_rank())

View File

@ -477,9 +477,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
# superfluous, but it cleans up some things # superfluous, but it cleans up some things
class TrainingState(): class TrainingState():
def __init__(self, config_path, keep_x_past_datasets=0, start=True): def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
# parse config to get its iteration # parse config to get its iteration
with open(config_path, 'r') as file: with open(config_path, 'r') as file:
self.config = yaml.safe_load(file) self.config = yaml.safe_load(file)
@ -530,9 +528,11 @@ class TrainingState():
if keep_x_past_datasets > 0: if keep_x_past_datasets > 0:
self.cleanup_old(keep=keep_x_past_datasets) self.cleanup_old(keep=keep_x_past_datasets)
if start: if start:
self.spawn_process() self.spawn_process(config_path=config_path, gpus=gpus)
def spawn_process(self, config_path, gpus=1):
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
def spawn_process(self):
print("Spawning process: ", " ".join(self.cmd)) print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
@ -745,7 +745,7 @@ class TrainingState():
if should_return: if should_return:
return "".join(self.buffer) return "".join(self.buffer)
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)): def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
global training_state global training_state
if training_state and training_state.process: if training_state and training_state.process:
return "Training already in progress" return "Training already in progress"
@ -757,7 +757,7 @@ def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets
unload_whisper() unload_whisper()
unload_voicefixer() unload_voicefixer()
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets) training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
@ -785,11 +785,13 @@ def update_training_dataplot(config_path=None):
update = None update = None
if not training_state: if not training_state:
training_state = TrainingState(config_path=config_path, start=False) if config_path:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses)) training_state = TrainingState(config_path=config_path, start=False)
del training_state if training_state.losses:
training_state = None update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
else: del training_state
training_state = None
elif training_state.losses:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses)) update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
return update return update
@ -1285,12 +1287,13 @@ def setup_args():
if not args.device_override: if not args.device_override:
set_device_name(args.device_override) set_device_name(args.device_override)
args.listen_host = None args.listen_host = None
args.listen_port = None args.listen_port = None
args.listen_path = None args.listen_path = None
if args.listen: if args.listen:
try: try:
match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0] match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
args.listen_host = match[0] if match[0] != "" else "127.0.0.1" args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
args.listen_port = match[1] if match[1] != "" else None args.listen_port = match[1] if match[1] != "" else None

View File

@ -546,6 +546,7 @@ def setup_gradio():
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8) training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_gpu_count = gr.Number(label="GPUs", value=1)
with gr.Row(): with gr.Row():
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop") stop_training_button = gr.Button(value="Stop")
@ -751,6 +752,7 @@ def setup_gradio():
inputs=[ inputs=[
training_configs, training_configs,
verbose_training, verbose_training,
training_gpu_count,
training_buffer_size, training_buffer_size,
training_keep_x_past_datasets, training_keep_x_past_datasets,
], ],

View File

@ -1,4 +1,13 @@
#!/bin/bash #!/bin/bash
source ./venv/bin/activate source ./venv/bin/activate
python3 ./src/train.py -opt "$1"
GPUS=$1
CONFIG=$2
PORT=1234
if (( $GPUS > 1 )); then
python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
else
python3 ./src/train.py -opt "$CONFIG"
fi
deactivate deactivate