a bunch of shit i had uncommited over the past while pertaining to VALL-E

remotes/1715375133522516288/master
mrq 2023-04-12 20:02:46 +07:00
parent b785192dfc
commit d8b996911c
3 changed files with 94 additions and 21 deletions

@ -1 +1 @@
Subproject commit 0bcdf81d0444218b4dedaefa5c546d42f36b8130
Subproject commit f025470d60fd18993caaa651e6faa585bcc420f0

@ -75,6 +75,7 @@ try:
VALLE_ENABLED = True
except Exception as e:
print(e)
pass
if VALLE_ENABLED:
@ -156,10 +157,12 @@ def generate_valle(**kwargs):
voice_cache = {}
def fetch_voice( voice ):
voice_dir = f'./voices/{voice}/'
voice_dir = f'./training/{voice}/audio/'
if not os.path.isdir(voice_dir):
voice_dir = f'./voices/{voice}/'
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
return files
# return random.choice(files)
# return files
return random.choice(files)
def get_settings( override=None ):
settings = {
@ -1089,13 +1092,13 @@ class TrainingState():
'ar-quarter.lr', 'nar-quarter.lr',
]
keys['losses'] = [
'ar.loss', 'nar.loss',
'ar-half.loss', 'nar-half.loss',
'ar-quarter.loss', 'nar-quarter.loss',
'ar.loss', 'nar.loss', 'ar+nar.loss',
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
'ar.loss.nll', 'nar.loss.nll',
'ar-half.loss.nll', 'nar-half.loss.nll',
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
# 'ar.loss.nll', 'nar.loss.nll',
# 'ar-half.loss.nll', 'nar-half.loss.nll',
# 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
]
keys['accuracies'] = [
@ -1123,7 +1126,7 @@ class TrainingState():
prefix = ""
if data["mode"] == "validation":
if "mode" in self.info and self.info["mode"] == "validation":
prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
@ -1231,6 +1234,7 @@ class TrainingState():
unq = {}
averager = None
prev_state = 0
for log in logs:
with open(log, 'r', encoding="utf-8") as f:
@ -1250,6 +1254,7 @@ class TrainingState():
name = "train"
mode = "training"
prev_state = 0
elif line.find('Validation Metrics:') >= 0:
data = json.loads(line.split("Validation Metrics:")[-1])
if "it" not in data:
@ -1257,8 +1262,15 @@ class TrainingState():
if "epoch" not in data:
data['epoch'] = epoch
name = data['name'] if 'name' in data else "val"
# name = data['name'] if 'name' in data else "val"
mode = "validation"
if prev_state == 0:
name = "subtrain"
else:
name = "val"
prev_state += 1
else:
continue
@ -1272,6 +1284,7 @@ class TrainingState():
if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
averager = {
'key': f'{it}_{name}',
'name': name,
'mode': mode,
"metrics": {}
}
@ -1292,11 +1305,13 @@ class TrainingState():
if update and it <= self.last_info_check_at:
continue
blacklist = [ "batch", "eval" ]
for it in unq:
if args.tts_backend == "vall-e":
stats = unq[it]
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()}
data['mode'] = stats
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
data['name'] = stats['name']
data['mode'] = stats['mode']
data['steps'] = len(stats['metrics']['it'])
else:
data = unq[it]
@ -1633,6 +1648,7 @@ def whisper_transcribe( file, language=None ):
device = "cuda" if get_device_name() == "cuda" else "cpu"
if whisper_vad:
# omits a considerable amount of the end
"""
if args.whisper_batchsize > 1:
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
@ -1778,7 +1794,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
messages = []
if not os.path.exists(infile):
raise Exception(f"Missing dataset: {infile}")
message = f"Missing dataset: {infile}"
print(message)
return message
if results is None:
results = json.load(open(infile, 'r', encoding="utf-8"))
@ -1903,7 +1921,9 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json'
if not os.path.exists(infile):
raise Exception(f"Missing dataset: {infile}")
message = f"Missing dataset: {infile}"
print(message)
return message
results = json.load(open(infile, 'r', encoding="utf-8"))

@ -196,6 +196,50 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ):
return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress )
def diarize_dataset( voice, progress=gr.Progress(track_tqdm=False) ):
from pyannote.audio import Pipeline
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=args.hf_token)
messages = []
files = sorted( get_voices(load_latents=False)[voice] )
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
diarization = pipeline(file)
for turn, _, speaker in diarization.itertracks(yield_label=True):
message = f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}"
print(message)
messages.append(message)
return "\n".join(messages)
def prepare_all_datasets( language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
kwargs = locals()
messages = []
voices = get_voice_list()
"""
for voice in voices:
message = prepare_dataset_proxy(voice, **kwargs)
messages.append(message)
"""
for voice in voices:
print("Processing:", voice)
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
messages.append(message)
if slice_audio:
for voice in voices:
print("Processing:", voice)
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
messages.append(message)
for voice in voices:
print("Processing:", voice)
message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length, progress=progress )
messages.append(message)
return "\n".join(messages)
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
messages = []
@ -468,6 +512,8 @@ def setup_gradio():
DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0)
transcribe_button = gr.Button(value="Transcribe and Process")
transcribe_all_button = gr.Button(value="Transcribe All")
diarize_button = gr.Button(value="Diarize")
with gr.Row():
slice_dataset_button = gr.Button(value="(Re)Slice Audio")
@ -579,7 +625,7 @@ def setup_gradio():
tooltip=['epoch', 'it', 'value', 'type'],
width=500,
height=350,
visible=args.tts_backend=="vall-e"
visible=False, # args.tts_backend=="vall-e"
)
view_losses = gr.Button(value="View Losses")
@ -611,10 +657,7 @@ def setup_gradio():
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
with gr.Column(visible=args.tts_backend=="vall-e"):
default_valle_model_choice = ""
if len(valle_models):
default_valle_model_choice = valle_models[0]
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else default_valle_model_choice)
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
with gr.Column(visible=args.tts_backend=="tortoise"):
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
@ -859,6 +902,16 @@ def setup_gradio():
inputs=dataset_settings,
outputs=prepare_dataset_output #console_output
)
transcribe_all_button.click(
prepare_all_datasets,
inputs=dataset_settings[1:],
outputs=prepare_dataset_output #console_output
)
diarize_button.click(
diarize_dataset,
inputs=dataset_settings[0],
outputs=prepare_dataset_output #console_output
)
prepare_dataset_button.click(
prepare_dataset,
inputs=[