|
|
|
@ -477,9 +477,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
|
|
|
|
|
|
|
|
|
# superfluous, but it cleans up some things
|
|
|
|
|
class TrainingState():
|
|
|
|
|
def __init__(self, config_path, keep_x_past_datasets=0, start=True):
|
|
|
|
|
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
|
|
|
|
|
|
|
|
|
|
def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
|
|
|
|
|
# parse config to get its iteration
|
|
|
|
|
with open(config_path, 'r') as file:
|
|
|
|
|
self.config = yaml.safe_load(file)
|
|
|
|
@ -530,9 +528,11 @@ class TrainingState():
|
|
|
|
|
if keep_x_past_datasets > 0:
|
|
|
|
|
self.cleanup_old(keep=keep_x_past_datasets)
|
|
|
|
|
if start:
|
|
|
|
|
self.spawn_process()
|
|
|
|
|
self.spawn_process(config_path=config_path, gpus=gpus)
|
|
|
|
|
|
|
|
|
|
def spawn_process(self, config_path, gpus=1):
|
|
|
|
|
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
|
|
|
|
|
|
|
|
|
|
def spawn_process(self):
|
|
|
|
|
print("Spawning process: ", " ".join(self.cmd))
|
|
|
|
|
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
|
|
|
|
|
|
|
|
@ -745,7 +745,7 @@ class TrainingState():
|
|
|
|
|
if should_return:
|
|
|
|
|
return "".join(self.buffer)
|
|
|
|
|
|
|
|
|
|
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
|
def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
|
global training_state
|
|
|
|
|
if training_state and training_state.process:
|
|
|
|
|
return "Training already in progress"
|
|
|
|
@ -757,7 +757,7 @@ def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets
|
|
|
|
|
unload_whisper()
|
|
|
|
|
unload_voicefixer()
|
|
|
|
|
|
|
|
|
|
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
|
|
|
|
|
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
|
|
|
|
|
|
|
|
|
|
for line in iter(training_state.process.stdout.readline, ""):
|
|
|
|
|
|
|
|
|
@ -785,11 +785,13 @@ def update_training_dataplot(config_path=None):
|
|
|
|
|
update = None
|
|
|
|
|
|
|
|
|
|
if not training_state:
|
|
|
|
|
training_state = TrainingState(config_path=config_path, start=False)
|
|
|
|
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
|
|
|
|
del training_state
|
|
|
|
|
training_state = None
|
|
|
|
|
else:
|
|
|
|
|
if config_path:
|
|
|
|
|
training_state = TrainingState(config_path=config_path, start=False)
|
|
|
|
|
if training_state.losses:
|
|
|
|
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
|
|
|
|
del training_state
|
|
|
|
|
training_state = None
|
|
|
|
|
elif training_state.losses:
|
|
|
|
|
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
|
|
|
|
|
|
|
|
|
|
return update
|
|
|
|
@ -1285,12 +1287,13 @@ def setup_args():
|
|
|
|
|
if not args.device_override:
|
|
|
|
|
set_device_name(args.device_override)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
args.listen_host = None
|
|
|
|
|
args.listen_port = None
|
|
|
|
|
args.listen_path = None
|
|
|
|
|
if args.listen:
|
|
|
|
|
try:
|
|
|
|
|
match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]
|
|
|
|
|
match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
|
|
|
|
|
|
|
|
|
|
args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
|
|
|
|
|
args.listen_port = match[1] if match[1] != "" else None
|
|
|
|
|