added helper script to cull short enough lines from training set as a validation set (if it yields good results doing validation during training, i'll add it to the web ui)
This commit is contained in:
parent
7f89e8058a
commit
fe8bf7a9d1
35
src/cull_dataset.py
Executable file
35
src/cull_dataset.py
Executable file
|
@ -0,0 +1,35 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
indir = f'./training/{sys.argv[1]}/'
|
||||
cap = int(sys.argv[2])
|
||||
|
||||
if not os.path.isdir(indir):
|
||||
raise Exception(f"Invalid directory: {indir}")
|
||||
|
||||
if not os.path.exists(f'{indir}/train.txt'):
|
||||
raise Exception(f"Missing dataset: {indir}/train.txt")
|
||||
|
||||
with open(f'{indir}/train.txt', 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
validation = []
|
||||
training = []
|
||||
|
||||
for line in lines:
|
||||
split = line.split("|")
|
||||
filename = split[0]
|
||||
text = split[1]
|
||||
|
||||
if len(text) < cap:
|
||||
validation.append(line.strip())
|
||||
else:
|
||||
training.append(line.strip())
|
||||
|
||||
with open(f'{indir}/train_culled.txt', 'w', encoding="utf-8") as f:
|
||||
f.write("\n".join(training))
|
||||
|
||||
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||
f.write("\n".join(validation))
|
||||
|
||||
print(f"Culled {len(validation)} lines")
|
18
src/utils.py
18
src/utils.py
|
@ -605,7 +605,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
|||
|
||||
# superfluous, but it cleans up some things
|
||||
class TrainingState():
|
||||
def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
|
||||
def __init__(self, config_path, keep_x_past_checkpoints=0, start=True, gpus=1):
|
||||
# parse config to get its iteration
|
||||
with open(config_path, 'r') as file:
|
||||
self.config = yaml.safe_load(file)
|
||||
|
@ -664,8 +664,8 @@ class TrainingState():
|
|||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||
|
||||
self.load_losses()
|
||||
if keep_x_past_datasets > 0:
|
||||
self.cleanup_old(keep=keep_x_past_datasets)
|
||||
if keep_x_past_checkpoints > 0:
|
||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||
if start:
|
||||
self.spawn_process(config_path=config_path, gpus=gpus)
|
||||
|
||||
|
@ -772,7 +772,7 @@ class TrainingState():
|
|||
print("Removing", path)
|
||||
os.remove(path)
|
||||
|
||||
def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ):
|
||||
def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
|
||||
self.buffer.append(f'{line}')
|
||||
|
||||
should_return = False
|
||||
|
@ -830,7 +830,7 @@ class TrainingState():
|
|||
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||
|
||||
self.cleanup_old(keep=keep_x_past_datasets)
|
||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||
|
||||
if line.find('%|') > 0:
|
||||
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||
|
@ -986,7 +986,7 @@ class TrainingState():
|
|||
message,
|
||||
)
|
||||
|
||||
def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
|
||||
def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
|
||||
global training_state
|
||||
if training_state and training_state.process:
|
||||
return "Training already in progress"
|
||||
|
@ -1008,13 +1008,13 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
|
|||
unload_whisper()
|
||||
unload_voicefixer()
|
||||
|
||||
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
|
||||
training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints, gpus=gpus)
|
||||
|
||||
for line in iter(training_state.process.stdout.readline, ""):
|
||||
if training_state.killed:
|
||||
return
|
||||
|
||||
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
|
||||
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
|
||||
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
|
||||
if result:
|
||||
yield result
|
||||
|
@ -1164,7 +1164,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
|||
|
||||
for line in parsed_list:
|
||||
match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
|
||||
print(match)
|
||||
|
||||
if match is None or len(match) == 0:
|
||||
continue
|
||||
|
||||
|
|
|
@ -559,7 +559,7 @@ def setup_gradio():
|
|||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||
|
||||
with gr.Row():
|
||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
training_keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
|
||||
with gr.Row():
|
||||
start_training_button = gr.Button(value="Train")
|
||||
|
@ -777,7 +777,7 @@ def setup_gradio():
|
|||
training_configs,
|
||||
verbose_training,
|
||||
training_gpu_count,
|
||||
training_keep_x_past_datasets,
|
||||
training_keep_x_past_checkpoints,
|
||||
],
|
||||
outputs=[
|
||||
training_output,
|
||||
|
|
Loading…
Reference in New Issue
Block a user