forked from mrq/ai-voice-cloning
experimental multi-gpu training (Linux only, because I can't into batch files)
This commit is contained in:
parent
e205322c8d
commit
e859a7c01d
7
.gitignore
vendored
7
.gitignore
vendored
|
@ -1,7 +1,8 @@
|
|||
# ignores user files
|
||||
/tortoise-venv/
|
||||
/tortoise/voices/
|
||||
/models/
|
||||
/venv/
|
||||
/voices/*
|
||||
/models/*
|
||||
/training/*
|
||||
/config/*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
|
|
|
@ -18,9 +18,12 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, help='Rank Number')
|
||||
args = parser.parse_args()
|
||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
||||
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
with open(args.opt, 'r') as file:
|
||||
opt_config = yaml.safe_load(file)
|
||||
|
||||
|
@ -71,7 +74,7 @@ def train(yaml, launcher='none'):
|
|||
print('Disabled distributed training.')
|
||||
else:
|
||||
opt['dist'] = True
|
||||
init_dist('nccl')
|
||||
tr.init_dist('nccl')
|
||||
trainer.world_size = torch.distributed.get_world_size()
|
||||
trainer.rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
|
|
21
src/utils.py
21
src/utils.py
|
@ -477,9 +477,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
|||
|
||||
# superfluous, but it cleans up some things
|
||||
class TrainingState():
|
||||
def __init__(self, config_path, keep_x_past_datasets=0, start=True):
|
||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
|
||||
|
||||
def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
|
||||
# parse config to get its iteration
|
||||
with open(config_path, 'r') as file:
|
||||
self.config = yaml.safe_load(file)
|
||||
|
@ -530,9 +528,11 @@ class TrainingState():
|
|||
if keep_x_past_datasets > 0:
|
||||
self.cleanup_old(keep=keep_x_past_datasets)
|
||||
if start:
|
||||
self.spawn_process()
|
||||
self.spawn_process(config_path=config_path, gpus=gpus)
|
||||
|
||||
def spawn_process(self, config_path, gpus=1):
|
||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
|
||||
|
||||
def spawn_process(self):
|
||||
print("Spawning process: ", " ".join(self.cmd))
|
||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
||||
|
@ -745,7 +745,7 @@ class TrainingState():
|
|||
if should_return:
|
||||
return "".join(self.buffer)
|
||||
|
||||
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||
def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||
global training_state
|
||||
if training_state and training_state.process:
|
||||
return "Training already in progress"
|
||||
|
@ -757,7 +757,7 @@ def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets
|
|||
unload_whisper()
|
||||
unload_voicefixer()
|
||||
|
||||
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
|
||||
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
|
||||
|
||||
for line in iter(training_state.process.stdout.readline, ""):
|
||||
|
||||
|
@ -785,11 +785,13 @@ def update_training_dataplot(config_path=None):
|
|||
update = None
|
||||
|
||||
if not training_state:
|
||||
if config_path:
|
||||
training_state = TrainingState(config_path=config_path, start=False)
|
||||
if training_state.losses:
|
||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
||||
del training_state
|
||||
training_state = None
|
||||
else:
|
||||
elif training_state.losses:
|
||||
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
||||
|
||||
return update
|
||||
|
@ -1285,12 +1287,13 @@ def setup_args():
|
|||
if not args.device_override:
|
||||
set_device_name(args.device_override)
|
||||
|
||||
|
||||
args.listen_host = None
|
||||
args.listen_port = None
|
||||
args.listen_path = None
|
||||
if args.listen:
|
||||
try:
|
||||
match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]
|
||||
match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
|
||||
|
||||
args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
|
||||
args.listen_port = match[1] if match[1] != "" else None
|
||||
|
|
|
@ -546,6 +546,7 @@ def setup_gradio():
|
|||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
training_gpu_count = gr.Number(label="GPUs", value=1)
|
||||
with gr.Row():
|
||||
start_training_button = gr.Button(value="Train")
|
||||
stop_training_button = gr.Button(value="Stop")
|
||||
|
@ -751,6 +752,7 @@ def setup_gradio():
|
|||
inputs=[
|
||||
training_configs,
|
||||
verbose_training,
|
||||
training_gpu_count,
|
||||
training_buffer_size,
|
||||
training_keep_x_past_datasets,
|
||||
],
|
||||
|
|
11
train.sh
11
train.sh
|
@ -1,4 +1,13 @@
|
|||
#!/bin/bash
|
||||
source ./venv/bin/activate
|
||||
python3 ./src/train.py -opt "$1"
|
||||
|
||||
GPUS=$1
|
||||
CONFIG=$2
|
||||
PORT=1234
|
||||
|
||||
if (( $GPUS > 1 )); then
|
||||
python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
|
||||
else
|
||||
python3 ./src/train.py -opt "$CONFIG"
|
||||
fi
|
||||
deactivate
|
||||
|
|
Loading…
Reference in New Issue
Block a user