forked from mrq/ai-voice-cloning
added helper script to cull short enough lines from training set as a validation set (if it yields good results doing validation during training, i'll add it to the web ui)
This commit is contained in:
parent
7f89e8058a
commit
fe8bf7a9d1
35
src/cull_dataset.py
Executable file
35
src/cull_dataset.py
Executable file
|
@ -0,0 +1,35 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
indir = f'./training/{sys.argv[1]}/'
|
||||||
|
cap = int(sys.argv[2])
|
||||||
|
|
||||||
|
if not os.path.isdir(indir):
|
||||||
|
raise Exception(f"Invalid directory: {indir}")
|
||||||
|
|
||||||
|
if not os.path.exists(f'{indir}/train.txt'):
|
||||||
|
raise Exception(f"Missing dataset: {indir}/train.txt")
|
||||||
|
|
||||||
|
with open(f'{indir}/train.txt', 'r', encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
validation = []
|
||||||
|
training = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
split = line.split("|")
|
||||||
|
filename = split[0]
|
||||||
|
text = split[1]
|
||||||
|
|
||||||
|
if len(text) < cap:
|
||||||
|
validation.append(line.strip())
|
||||||
|
else:
|
||||||
|
training.append(line.strip())
|
||||||
|
|
||||||
|
with open(f'{indir}/train_culled.txt', 'w', encoding="utf-8") as f:
|
||||||
|
f.write("\n".join(training))
|
||||||
|
|
||||||
|
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||||
|
f.write("\n".join(validation))
|
||||||
|
|
||||||
|
print(f"Culled {len(validation)} lines")
|
18
src/utils.py
18
src/utils.py
|
@ -605,7 +605,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
||||||
|
|
||||||
# superfluous, but it cleans up some things
|
# superfluous, but it cleans up some things
|
||||||
class TrainingState():
|
class TrainingState():
|
||||||
def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
|
def __init__(self, config_path, keep_x_past_checkpoints=0, start=True, gpus=1):
|
||||||
# parse config to get its iteration
|
# parse config to get its iteration
|
||||||
with open(config_path, 'r') as file:
|
with open(config_path, 'r') as file:
|
||||||
self.config = yaml.safe_load(file)
|
self.config = yaml.safe_load(file)
|
||||||
|
@ -664,8 +664,8 @@ class TrainingState():
|
||||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||||
|
|
||||||
self.load_losses()
|
self.load_losses()
|
||||||
if keep_x_past_datasets > 0:
|
if keep_x_past_checkpoints > 0:
|
||||||
self.cleanup_old(keep=keep_x_past_datasets)
|
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||||
if start:
|
if start:
|
||||||
self.spawn_process(config_path=config_path, gpus=gpus)
|
self.spawn_process(config_path=config_path, gpus=gpus)
|
||||||
|
|
||||||
|
@ -772,7 +772,7 @@ class TrainingState():
|
||||||
print("Removing", path)
|
print("Removing", path)
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ):
|
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
|
||||||
self.buffer.append(f'{line}')
|
self.buffer.append(f'{line}')
|
||||||
|
|
||||||
should_return = False
|
should_return = False
|
||||||
|
@ -830,7 +830,7 @@ class TrainingState():
|
||||||
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||||
|
|
||||||
self.cleanup_old(keep=keep_x_past_datasets)
|
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||||
|
|
||||||
if line.find('%|') > 0:
|
if line.find('%|') > 0:
|
||||||
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||||
|
@ -986,7 +986,7 @@ class TrainingState():
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
|
||||||
global training_state
|
global training_state
|
||||||
if training_state and training_state.process:
|
if training_state and training_state.process:
|
||||||
return "Training already in progress"
|
return "Training already in progress"
|
||||||
|
@ -1008,13 +1008,13 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
|
||||||
unload_whisper()
|
unload_whisper()
|
||||||
unload_voicefixer()
|
unload_voicefixer()
|
||||||
|
|
||||||
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
|
training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints, gpus=gpus)
|
||||||
|
|
||||||
for line in iter(training_state.process.stdout.readline, ""):
|
for line in iter(training_state.process.stdout.readline, ""):
|
||||||
if training_state.killed:
|
if training_state.killed:
|
||||||
return
|
return
|
||||||
|
|
||||||
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
|
||||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||||
if result:
|
if result:
|
||||||
yield result
|
yield result
|
||||||
|
@ -1164,7 +1164,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
||||||
|
|
||||||
for line in parsed_list:
|
for line in parsed_list:
|
||||||
match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
|
match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
|
||||||
print(match)
|
|
||||||
if match is None or len(match) == 0:
|
if match is None or len(match) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@ -559,7 +559,7 @@ def setup_gradio():
|
||||||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
training_keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||||
training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
|
training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_training_button = gr.Button(value="Train")
|
start_training_button = gr.Button(value="Train")
|
||||||
|
@ -777,7 +777,7 @@ def setup_gradio():
|
||||||
training_configs,
|
training_configs,
|
||||||
verbose_training,
|
verbose_training,
|
||||||
training_gpu_count,
|
training_gpu_count,
|
||||||
training_keep_x_past_datasets,
|
training_keep_x_past_checkpoints,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
training_output,
|
training_output,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user