a bunch of shit i had uncommited over the past while pertaining to VALL-E
This commit is contained in:
parent
b785192dfc
commit
d8b996911c
|
@ -1 +1 @@
|
||||||
Subproject commit 0bcdf81d0444218b4dedaefa5c546d42f36b8130
|
Subproject commit f025470d60fd18993caaa651e6faa585bcc420f0
|
50
src/utils.py
50
src/utils.py
|
@ -75,6 +75,7 @@ try:
|
||||||
|
|
||||||
VALLE_ENABLED = True
|
VALLE_ENABLED = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if VALLE_ENABLED:
|
if VALLE_ENABLED:
|
||||||
|
@ -156,10 +157,12 @@ def generate_valle(**kwargs):
|
||||||
|
|
||||||
voice_cache = {}
|
voice_cache = {}
|
||||||
def fetch_voice( voice ):
|
def fetch_voice( voice ):
|
||||||
voice_dir = f'./voices/{voice}/'
|
voice_dir = f'./training/{voice}/audio/'
|
||||||
|
if not os.path.isdir(voice_dir):
|
||||||
|
voice_dir = f'./voices/{voice}/'
|
||||||
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
|
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
|
||||||
return files
|
# return files
|
||||||
# return random.choice(files)
|
return random.choice(files)
|
||||||
|
|
||||||
def get_settings( override=None ):
|
def get_settings( override=None ):
|
||||||
settings = {
|
settings = {
|
||||||
|
@ -1089,13 +1092,13 @@ class TrainingState():
|
||||||
'ar-quarter.lr', 'nar-quarter.lr',
|
'ar-quarter.lr', 'nar-quarter.lr',
|
||||||
]
|
]
|
||||||
keys['losses'] = [
|
keys['losses'] = [
|
||||||
'ar.loss', 'nar.loss',
|
'ar.loss', 'nar.loss', 'ar+nar.loss',
|
||||||
'ar-half.loss', 'nar-half.loss',
|
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
|
||||||
'ar-quarter.loss', 'nar-quarter.loss',
|
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
|
||||||
|
|
||||||
'ar.loss.nll', 'nar.loss.nll',
|
# 'ar.loss.nll', 'nar.loss.nll',
|
||||||
'ar-half.loss.nll', 'nar-half.loss.nll',
|
# 'ar-half.loss.nll', 'nar-half.loss.nll',
|
||||||
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
||||||
]
|
]
|
||||||
|
|
||||||
keys['accuracies'] = [
|
keys['accuracies'] = [
|
||||||
|
@ -1123,7 +1126,7 @@ class TrainingState():
|
||||||
|
|
||||||
prefix = ""
|
prefix = ""
|
||||||
|
|
||||||
if data["mode"] == "validation":
|
if "mode" in self.info and self.info["mode"] == "validation":
|
||||||
prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
|
prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
|
||||||
|
|
||||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
|
||||||
|
@ -1231,6 +1234,7 @@ class TrainingState():
|
||||||
|
|
||||||
unq = {}
|
unq = {}
|
||||||
averager = None
|
averager = None
|
||||||
|
prev_state = 0
|
||||||
|
|
||||||
for log in logs:
|
for log in logs:
|
||||||
with open(log, 'r', encoding="utf-8") as f:
|
with open(log, 'r', encoding="utf-8") as f:
|
||||||
|
@ -1250,6 +1254,7 @@ class TrainingState():
|
||||||
|
|
||||||
name = "train"
|
name = "train"
|
||||||
mode = "training"
|
mode = "training"
|
||||||
|
prev_state = 0
|
||||||
elif line.find('Validation Metrics:') >= 0:
|
elif line.find('Validation Metrics:') >= 0:
|
||||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||||
if "it" not in data:
|
if "it" not in data:
|
||||||
|
@ -1257,8 +1262,15 @@ class TrainingState():
|
||||||
if "epoch" not in data:
|
if "epoch" not in data:
|
||||||
data['epoch'] = epoch
|
data['epoch'] = epoch
|
||||||
|
|
||||||
name = data['name'] if 'name' in data else "val"
|
# name = data['name'] if 'name' in data else "val"
|
||||||
mode = "validation"
|
mode = "validation"
|
||||||
|
|
||||||
|
if prev_state == 0:
|
||||||
|
name = "subtrain"
|
||||||
|
else:
|
||||||
|
name = "val"
|
||||||
|
|
||||||
|
prev_state += 1
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -1272,6 +1284,7 @@ class TrainingState():
|
||||||
if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
|
if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
|
||||||
averager = {
|
averager = {
|
||||||
'key': f'{it}_{name}',
|
'key': f'{it}_{name}',
|
||||||
|
'name': name,
|
||||||
'mode': mode,
|
'mode': mode,
|
||||||
"metrics": {}
|
"metrics": {}
|
||||||
}
|
}
|
||||||
|
@ -1292,11 +1305,13 @@ class TrainingState():
|
||||||
if update and it <= self.last_info_check_at:
|
if update and it <= self.last_info_check_at:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
blacklist = [ "batch", "eval" ]
|
||||||
for it in unq:
|
for it in unq:
|
||||||
if args.tts_backend == "vall-e":
|
if args.tts_backend == "vall-e":
|
||||||
stats = unq[it]
|
stats = unq[it]
|
||||||
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
|
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
|
||||||
data['mode'] = stats
|
data['name'] = stats['name']
|
||||||
|
data['mode'] = stats['mode']
|
||||||
data['steps'] = len(stats['metrics']['it'])
|
data['steps'] = len(stats['metrics']['it'])
|
||||||
else:
|
else:
|
||||||
data = unq[it]
|
data = unq[it]
|
||||||
|
@ -1633,6 +1648,7 @@ def whisper_transcribe( file, language=None ):
|
||||||
|
|
||||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||||
if whisper_vad:
|
if whisper_vad:
|
||||||
|
# omits a considerable amount of the end
|
||||||
"""
|
"""
|
||||||
if args.whisper_batchsize > 1:
|
if args.whisper_batchsize > 1:
|
||||||
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
|
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
|
||||||
|
@ -1778,7 +1794,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
if not os.path.exists(infile):
|
if not os.path.exists(infile):
|
||||||
raise Exception(f"Missing dataset: {infile}")
|
message = f"Missing dataset: {infile}"
|
||||||
|
print(message)
|
||||||
|
return message
|
||||||
|
|
||||||
if results is None:
|
if results is None:
|
||||||
results = json.load(open(infile, 'r', encoding="utf-8"))
|
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||||
|
@ -1903,7 +1921,9 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/whisper.json'
|
infile = f'{indir}/whisper.json'
|
||||||
if not os.path.exists(infile):
|
if not os.path.exists(infile):
|
||||||
raise Exception(f"Missing dataset: {infile}")
|
message = f"Missing dataset: {infile}"
|
||||||
|
print(message)
|
||||||
|
return message
|
||||||
|
|
||||||
results = json.load(open(infile, 'r', encoding="utf-8"))
|
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||||
|
|
||||||
|
|
63
src/webui.py
63
src/webui.py
|
@ -196,6 +196,50 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ):
|
def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ):
|
||||||
return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress )
|
return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress )
|
||||||
|
|
||||||
|
def diarize_dataset( voice, progress=gr.Progress(track_tqdm=False) ):
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=args.hf_token)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
files = sorted( get_voices(load_latents=False)[voice] )
|
||||||
|
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||||
|
diarization = pipeline(file)
|
||||||
|
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
message = f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
def prepare_all_datasets( language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
|
||||||
|
kwargs = locals()
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
voices = get_voice_list()
|
||||||
|
|
||||||
|
"""
|
||||||
|
for voice in voices:
|
||||||
|
message = prepare_dataset_proxy(voice, **kwargs)
|
||||||
|
messages.append(message)
|
||||||
|
"""
|
||||||
|
for voice in voices:
|
||||||
|
print("Processing:", voice)
|
||||||
|
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
if slice_audio:
|
||||||
|
for voice in voices:
|
||||||
|
print("Processing:", voice)
|
||||||
|
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
for voice in voices:
|
||||||
|
print("Processing:", voice)
|
||||||
|
message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length, progress=progress )
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
return "\n".join(messages)
|
||||||
|
|
||||||
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
|
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
|
@ -468,6 +512,8 @@ def setup_gradio():
|
||||||
DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0)
|
DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0)
|
||||||
|
|
||||||
transcribe_button = gr.Button(value="Transcribe and Process")
|
transcribe_button = gr.Button(value="Transcribe and Process")
|
||||||
|
transcribe_all_button = gr.Button(value="Transcribe All")
|
||||||
|
diarize_button = gr.Button(value="Diarize")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
slice_dataset_button = gr.Button(value="(Re)Slice Audio")
|
slice_dataset_button = gr.Button(value="(Re)Slice Audio")
|
||||||
|
@ -579,7 +625,7 @@ def setup_gradio():
|
||||||
tooltip=['epoch', 'it', 'value', 'type'],
|
tooltip=['epoch', 'it', 'value', 'type'],
|
||||||
width=500,
|
width=500,
|
||||||
height=350,
|
height=350,
|
||||||
visible=args.tts_backend=="vall-e"
|
visible=False, # args.tts_backend=="vall-e"
|
||||||
)
|
)
|
||||||
view_losses = gr.Button(value="View Losses")
|
view_losses = gr.Button(value="View Losses")
|
||||||
|
|
||||||
|
@ -611,10 +657,7 @@ def setup_gradio():
|
||||||
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
|
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
|
||||||
|
|
||||||
with gr.Column(visible=args.tts_backend=="vall-e"):
|
with gr.Column(visible=args.tts_backend=="vall-e"):
|
||||||
default_valle_model_choice = ""
|
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
|
||||||
if len(valle_models):
|
|
||||||
default_valle_model_choice = valle_models[0]
|
|
||||||
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else default_valle_model_choice)
|
|
||||||
|
|
||||||
with gr.Column(visible=args.tts_backend=="tortoise"):
|
with gr.Column(visible=args.tts_backend=="tortoise"):
|
||||||
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
|
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
|
||||||
|
@ -859,6 +902,16 @@ def setup_gradio():
|
||||||
inputs=dataset_settings,
|
inputs=dataset_settings,
|
||||||
outputs=prepare_dataset_output #console_output
|
outputs=prepare_dataset_output #console_output
|
||||||
)
|
)
|
||||||
|
transcribe_all_button.click(
|
||||||
|
prepare_all_datasets,
|
||||||
|
inputs=dataset_settings[1:],
|
||||||
|
outputs=prepare_dataset_output #console_output
|
||||||
|
)
|
||||||
|
diarize_button.click(
|
||||||
|
diarize_dataset,
|
||||||
|
inputs=dataset_settings[0],
|
||||||
|
outputs=prepare_dataset_output #console_output
|
||||||
|
)
|
||||||
prepare_dataset_button.click(
|
prepare_dataset_button.click(
|
||||||
prepare_dataset,
|
prepare_dataset,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|
Loading…
Reference in New Issue
Block a user