a bunch of shit i had uncommited over the past while pertaining to VALL-E

This commit is contained in:
mrq 2023-04-12 20:02:46 +00:00
parent b785192dfc
commit d8b996911c
3 changed files with 94 additions and 21 deletions

@ -1 +1 @@
Subproject commit 0bcdf81d0444218b4dedaefa5c546d42f36b8130 Subproject commit f025470d60fd18993caaa651e6faa585bcc420f0

View File

@ -75,6 +75,7 @@ try:
VALLE_ENABLED = True VALLE_ENABLED = True
except Exception as e: except Exception as e:
print(e)
pass pass
if VALLE_ENABLED: if VALLE_ENABLED:
@ -156,10 +157,12 @@ def generate_valle(**kwargs):
voice_cache = {} voice_cache = {}
def fetch_voice( voice ): def fetch_voice( voice ):
voice_dir = f'./training/{voice}/audio/'
if not os.path.isdir(voice_dir):
voice_dir = f'./voices/{voice}/' voice_dir = f'./voices/{voice}/'
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ] files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
return files # return files
# return random.choice(files) return random.choice(files)
def get_settings( override=None ): def get_settings( override=None ):
settings = { settings = {
@ -1089,13 +1092,13 @@ class TrainingState():
'ar-quarter.lr', 'nar-quarter.lr', 'ar-quarter.lr', 'nar-quarter.lr',
] ]
keys['losses'] = [ keys['losses'] = [
'ar.loss', 'nar.loss', 'ar.loss', 'nar.loss', 'ar+nar.loss',
'ar-half.loss', 'nar-half.loss', 'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
'ar.loss.nll', 'nar.loss.nll', # 'ar.loss.nll', 'nar.loss.nll',
'ar-half.loss.nll', 'nar-half.loss.nll', # 'ar-half.loss.nll', 'nar-half.loss.nll',
'ar-quarter.loss.nll', 'nar-quarter.loss.nll', # 'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
] ]
keys['accuracies'] = [ keys['accuracies'] = [
@ -1123,7 +1126,7 @@ class TrainingState():
prefix = "" prefix = ""
if data["mode"] == "validation": if "mode" in self.info and self.info["mode"] == "validation":
prefix = f'{self.info["name"] if "name" in self.info else "val"}_' prefix = f'{self.info["name"] if "name" in self.info else "val"}_'
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' }) self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{prefix}{k}' })
@ -1231,6 +1234,7 @@ class TrainingState():
unq = {} unq = {}
averager = None averager = None
prev_state = 0
for log in logs: for log in logs:
with open(log, 'r', encoding="utf-8") as f: with open(log, 'r', encoding="utf-8") as f:
@ -1250,6 +1254,7 @@ class TrainingState():
name = "train" name = "train"
mode = "training" mode = "training"
prev_state = 0
elif line.find('Validation Metrics:') >= 0: elif line.find('Validation Metrics:') >= 0:
data = json.loads(line.split("Validation Metrics:")[-1]) data = json.loads(line.split("Validation Metrics:")[-1])
if "it" not in data: if "it" not in data:
@ -1257,8 +1262,15 @@ class TrainingState():
if "epoch" not in data: if "epoch" not in data:
data['epoch'] = epoch data['epoch'] = epoch
name = data['name'] if 'name' in data else "val" # name = data['name'] if 'name' in data else "val"
mode = "validation" mode = "validation"
if prev_state == 0:
name = "subtrain"
else:
name = "val"
prev_state += 1
else: else:
continue continue
@ -1272,6 +1284,7 @@ class TrainingState():
if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode: if not averager or averager['key'] != f'{it}_{name}' or averager['mode'] != mode:
averager = { averager = {
'key': f'{it}_{name}', 'key': f'{it}_{name}',
'name': name,
'mode': mode, 'mode': mode,
"metrics": {} "metrics": {}
} }
@ -1292,11 +1305,13 @@ class TrainingState():
if update and it <= self.last_info_check_at: if update and it <= self.last_info_check_at:
continue continue
blacklist = [ "batch", "eval" ]
for it in unq: for it in unq:
if args.tts_backend == "vall-e": if args.tts_backend == "vall-e":
stats = unq[it] stats = unq[it]
data = {k: sum(v) / len(v) for k, v in stats['metrics'].items()} data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist }
data['mode'] = stats data['name'] = stats['name']
data['mode'] = stats['mode']
data['steps'] = len(stats['metrics']['it']) data['steps'] = len(stats['metrics']['it'])
else: else:
data = unq[it] data = unq[it]
@ -1633,6 +1648,7 @@ def whisper_transcribe( file, language=None ):
device = "cuda" if get_device_name() == "cuda" else "cpu" device = "cuda" if get_device_name() == "cuda" else "cpu"
if whisper_vad: if whisper_vad:
# omits a considerable amount of the end
""" """
if args.whisper_batchsize > 1: if args.whisper_batchsize > 1:
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe") result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
@ -1778,7 +1794,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
messages = [] messages = []
if not os.path.exists(infile): if not os.path.exists(infile):
raise Exception(f"Missing dataset: {infile}") message = f"Missing dataset: {infile}"
print(message)
return message
if results is None: if results is None:
results = json.load(open(infile, 'r', encoding="utf-8")) results = json.load(open(infile, 'r', encoding="utf-8"))
@ -1903,7 +1921,9 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
indir = f'./training/{voice}/' indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json' infile = f'{indir}/whisper.json'
if not os.path.exists(infile): if not os.path.exists(infile):
raise Exception(f"Missing dataset: {infile}") message = f"Missing dataset: {infile}"
print(message)
return message
results = json.load(open(infile, 'r', encoding="utf-8")) results = json.load(open(infile, 'r', encoding="utf-8"))

View File

@ -196,6 +196,50 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ): def slice_dataset_proxy( voice, trim_silence, start_offset, end_offset, progress=gr.Progress(track_tqdm=True) ):
return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress ) return slice_dataset( voice, trim_silence=trim_silence, start_offset=start_offset, end_offset=end_offset, results=None, progress=progress )
def diarize_dataset( voice, progress=gr.Progress(track_tqdm=False) ):
from pyannote.audio import Pipeline
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=args.hf_token)
messages = []
files = sorted( get_voices(load_latents=False)[voice] )
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
diarization = pipeline(file)
for turn, _, speaker in diarization.itertracks(yield_label=True):
message = f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}"
print(message)
messages.append(message)
return "\n".join(messages)
def prepare_all_datasets( language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
kwargs = locals()
messages = []
voices = get_voice_list()
"""
for voice in voices:
message = prepare_dataset_proxy(voice, **kwargs)
messages.append(message)
"""
for voice in voices:
print("Processing:", voice)
message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
messages.append(message)
if slice_audio:
for voice in voices:
print("Processing:", voice)
message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress )
messages.append(message)
for voice in voices:
print("Processing:", voice)
message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length, progress=progress )
messages.append(message)
return "\n".join(messages)
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ): def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, trim_silence, slice_start_offset, slice_end_offset, progress=gr.Progress(track_tqdm=False) ):
messages = [] messages = []
@ -468,6 +512,8 @@ def setup_gradio():
DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0) DATASET_SETTINGS['slice_end_offset'] = gr.Number(label="Slice End Offset", value=0)
transcribe_button = gr.Button(value="Transcribe and Process") transcribe_button = gr.Button(value="Transcribe and Process")
transcribe_all_button = gr.Button(value="Transcribe All")
diarize_button = gr.Button(value="Diarize")
with gr.Row(): with gr.Row():
slice_dataset_button = gr.Button(value="(Re)Slice Audio") slice_dataset_button = gr.Button(value="(Re)Slice Audio")
@ -579,7 +625,7 @@ def setup_gradio():
tooltip=['epoch', 'it', 'value', 'type'], tooltip=['epoch', 'it', 'value', 'type'],
width=500, width=500,
height=350, height=350,
visible=args.tts_backend=="vall-e" visible=False, # args.tts_backend=="vall-e"
) )
view_losses = gr.Button(value="View Losses") view_losses = gr.Button(value="View Losses")
@ -611,10 +657,7 @@ def setup_gradio():
# EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0]) # EXEC_SETTINGS['tts_backend'] = gr.Dropdown(TTSES, label="TTS Backend", value=args.tts_backend if args.tts_backend else TTSES[0])
with gr.Column(visible=args.tts_backend=="vall-e"): with gr.Column(visible=args.tts_backend=="vall-e"):
default_valle_model_choice = "" EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else valle_models[0])
if len(valle_models):
default_valle_model_choice = valle_models[0]
EXEC_SETTINGS['valle_model'] = gr.Dropdown(choices=valle_models, label="VALL-E Model Config", value=args.valle_model if args.valle_model else default_valle_model_choice)
with gr.Column(visible=args.tts_backend=="tortoise"): with gr.Column(visible=args.tts_backend=="tortoise"):
EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto") EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=["auto"] + autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else "auto")
@ -859,6 +902,16 @@ def setup_gradio():
inputs=dataset_settings, inputs=dataset_settings,
outputs=prepare_dataset_output #console_output outputs=prepare_dataset_output #console_output
) )
transcribe_all_button.click(
prepare_all_datasets,
inputs=dataset_settings[1:],
outputs=prepare_dataset_output #console_output
)
diarize_button.click(
diarize_dataset,
inputs=dataset_settings[0],
outputs=prepare_dataset_output #console_output
)
prepare_dataset_button.click( prepare_dataset_button.click(
prepare_dataset, prepare_dataset,
inputs=[ inputs=[