a bunch of shit i had uncommited over the past while pertaining to VALL-E
This commit is contained in:
parent
b785192dfc
commit
d8b996911c
|
@ -1 +1 @@
|
|||
Subproject commit 0bcdf81d0444218b4dedaefa5c546d42f36b8130
|
||||
Subproject commit f025470d60fd18993caaa651e6faa585bcc420f0
|
48
src/utils.py
48
src/utils.py
|
@ -75,6 +75,7 @@ try:
|
|||
|
||||
VALLE_ENABLED = True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
if VALLE_ENABLED:
|
||||
|
@ -156,10 +157,12 @@ def generate_valle(**kwargs):
|
|||
|
||||
voice_cache = {}
|
||||
def fetch_voice( voice ):
|
||||
voice_dir = f'./training/{voice}/audio/'
|
||||
if not os.path.isdir(voice_dir):
|
||||
voice_dir = f'./voices/{voice}/'
|
||||
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
|
||||
return files
|
||||
# return random.choice(files)
|
||||
# return files
|
||||
return random.choice(files)
|
||||
|
||||
def get_settings( override=None ):
|
||||
settings = {
|
||||
|
@ -1089,13 +1092,13 @@ class TrainingState():
|
|||
'ar-quarter.lr', 'nar-quarter.lr',
|
||||
]
|
||||
keys['losses'] = [
|
||||
'ar.loss', 'nar.loss',
|
||||
'ar-half.loss', 'nar-half.loss',
|
||||
'ar-quarter.loss', 'nar-quarter.loss',
|
||||
'ar.loss', 'nar.loss', 'ar+nar.loss',
|
||||
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
|
||||
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
|
||||
|
||||
'ar.loss.nll', 'nar.loss.nll',
|
||||
'ar-half.loss.nll', 'nar-half.loss.nll',
|
||||
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
||||
# 'ar.loss.nll', 'nar.loss.nll',
|
||||
# 'ar-half.loss.nll', 'nar-half.loss.nll',
|
||||
# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
||||
]
|
||||
|
||||
keys['accuracies'] = [
|
||||
|
@ -1123,7 +1126,7 @@ class TrainingState():
|
|||
|
||||
prefix = ""
|
||||
|
||||
if data["mode"] == "validation":
|
||||
if "mode" in self.info and self.info["mode"] == "validation":
|
||||
prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
|
||||
|
||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
|
||||
|
@ -1231,6 +1234,7 @@ class TrainingState():
|
|||
|
||||
unq = {}
|
||||
averager = None
|
||||
prev_state = 0
|
||||
|
||||
for log in logs:
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
|
@ -1250,6 +1254,7 @@ class TrainingState():
|
|||
|
||||
name = "train"
|
||||
mode = "training"
|
||||
prev_state = 0
|
||||
elif line.find('Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||
if "it" not in data:
|
||||
|
@ -1257,8 +1262,15 @@ class TrainingState():
|
|||
if "epoch" not in data:
|
||||
data['epoch'] = epoch
|
||||
|
||||
name = data['name'] if 'name' in data else "val"
|
||||
# name = data['name'] if 'name' in data else "val"
|
||||
mode = "validation"
|
||||
|
||||
if prev_state == 0:
|
||||
name = "subtrain"
|
||||
else:
|
||||
name = "val"
|
||||
|
||||
prev_state += 1
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -1272,6 +1284,7 @@ class TrainingState():
|
|||
if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
|
||||
averager = {
|
||||
'key': f'{it}_{name}',
|
||||
'name': name,
|
||||
'mode': mode,
|
||||
"metrics": {}
|
||||
}
|
||||
|
@ -1292,11 +1305,13 @@ class TrainingState():
|
|||
if update and it <= self.last_info_check_at:
|
||||
continue
|
||||
|
||||
blacklist = [ "batch", "eval" ]
|
||||
for it in unq:
|
||||
if args.tts_backend == "vall-e":
|
||||
stats = unq[it]
|
||||
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
|
||||
data['mode'] = stats
|
||||
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
|
||||
data['name'] = stats['name']
|
||||
data['mode'] = stats['mode']
|
||||
data['steps'] = len(stats['metrics']['it'])
|
||||
else:
|
||||
data = unq[it]
|
||||
|
@ -1633,6 +1648,7 @@ def whisper_transcribe( file, language=None ):
|
|||
|
||||
device = "cuda" if get_device_name() == "cuda" else "cpu"
|
||||
if whisper_vad:
|
||||
# omits a considerable amount of the end
|
||||
"""
|
||||
if args.whisper_batchsize > 1:
|
||||
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
|
||||
|
@ -1778,7 +1794,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
|||
messages = []
|
||||
|
||||
if not os.path.exists(infile):
|
||||
raise Exception(f"Missing dataset: {infile}")
|
||||
message = f"Missing dataset: {infile}"
|
||||
print(message)
|
||||
return message
|
||||
|
||||
if results is None:
|
||||
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||
|
@ -1903,7 +1921,9 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
indir = f'./training/{voice}/'
|
||||
infile = f'{indir}/whisper.json'
|
||||
if not os.path.exists(infile):
|
||||
raise Exception(f"Missing dataset: {infile}")
|
||||
message = f"Missing dataset: {infile}"
|
||||
print(message)
|
||||
return message
|
||||
|
||||
results = json.load(open(infile, 'r', encoding="utf-8"))
|
||||
|
||||
|
|
63
src/webui.py
63
src/webui.py
|
@ -196,6 +196,50 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
|||
def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ):
|
||||
return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress )
|
||||
|
||||
def diarize_dataset( voice, progress=gr.Progress(track_tqdm=False) ):
|
||||
from pyannote.audio import Pipeline
|
||||
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=args.hf_token)
|
||||
|
||||
messages = []
|
||||
files = sorted( get_voices(load_latents=False)[voice] )
|
||||
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
|
||||
diarization = pipeline(file)
|
||||
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
||||
message = f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}"
|
||||
print(message)
|
||||
messages.append(message)
|
||||
|
||||
return "\n".join(messages)
|
||||
|
||||
def prepare_all_datasets( language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
|
||||
kwargs = locals()
|
||||
|
||||
messages = []
|
||||
voices = get_voice_list()
|
||||
|
||||
"""
|
||||
for voice in voices:
|
||||
message = prepare_dataset_proxy(voice, **kwargs)
|
||||
messages.append(message)
|
||||
"""
|
||||
for voice in voices:
|
||||
print("Processing:", voice)
|
||||
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
|
||||
messages.append(message)
|
||||
|
||||
if slice_audio:
|
||||
for voice in voices:
|
||||
print("Processing:", voice)
|
||||
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
|
||||
messages.append(message)
|
||||
|
||||
for voice in voices:
|
||||
print("Processing:", voice)
|
||||
message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length, progress=progress )
|
||||
messages.append(message)
|
||||
|
||||
return "\n".join(messages)
|
||||
|
||||
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
|
||||
messages = []
|
||||
|
||||
|
@ -468,6 +512,8 @@ def setup_gradio():
|
|||
DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0)
|
||||
|
||||
transcribe_button = gr.Button(value="Transcribe and Process")
|
||||
transcribe_all_button = gr.Button(value="Transcribe All")
|
||||
diarize_button = gr.Button(value="Diarize")
|
||||
|
||||
with gr.Row():
|
||||
slice_dataset_button = gr.Button(value="(Re)Slice Audio")
|
||||
|
@ -579,7 +625,7 @@ def setup_gradio():
|
|||
tooltip=['epoch', 'it', 'value', 'type'],
|
||||
width=500,
|
||||
height=350,
|
||||
visible=args.tts_backend=="vall-e"
|
||||
visible=False, # args.tts_backend=="vall-e"
|
||||
)
|
||||
view_losses = gr.Button(value="View Losses")
|
||||
|
||||
|
@ -611,10 +657,7 @@ def setup_gradio():
|
|||
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
|
||||
|
||||
with gr.Column(visible=args.tts_backend=="vall-e"):
|
||||
default_valle_model_choice = ""
|
||||
if len(valle_models):
|
||||
default_valle_model_choice = valle_models[0]
|
||||
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else default_valle_model_choice)
|
||||
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
|
||||
|
||||
with gr.Column(visible=args.tts_backend=="tortoise"):
|
||||
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
|
||||
|
@ -859,6 +902,16 @@ def setup_gradio():
|
|||
inputs=dataset_settings,
|
||||
outputs=prepare_dataset_output #console_output
|
||||
)
|
||||
transcribe_all_button.click(
|
||||
prepare_all_datasets,
|
||||
inputs=dataset_settings[1:],
|
||||
outputs=prepare_dataset_output #console_output
|
||||
)
|
||||
diarize_button.click(
|
||||
diarize_dataset,
|
||||
inputs=dataset_settings[0],
|
||||
outputs=prepare_dataset_output #console_output
|
||||
)
|
||||
prepare_dataset_button.click(
|
||||
prepare_dataset,
|
||||
inputs=[
|
||||
|
|
Loading…
Reference in New Issue
Block a user