Update 'src/utils.py'
commit e45ea6b26a
parent 2101131cfb

src/utils.py: 135 changed lines
@@ -27,7 +27,6 @@ import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
-import whisperx

 from datetime import datetime
 from datetime import timedelta
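The dropped whisperx import matches the larger change further down: the Whisper backend is now imported lazily inside load_whisper_model rather than at module load. A minimal sketch of that pattern, with use_cpp standing in for args.whisper_cpp; the whispercpp call mirrors the one introduced by this commit:

# Sketch only: lazy, backend-switched loading; use_cpp stands in for args.whisper_cpp.
def load_backend(model_name, use_cpp=False):
    if use_cpp:
        from whispercpp import Whisper  # imported only when actually requested
        return Whisper(model_name, models_dir='./models/', language='auto'.encode('ascii'))
    import whisper  # openai-whisper
    return whisper.load_model(model_name)
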
@@ -234,7 +233,7 @@ def generate(
 		if emotion == "Custom":
 			if prompt and prompt.strip() != "":
 				cut_text = f"[{prompt},] {cut_text}"
-		else:
+		elif emotion != "None":
 			cut_text = f"[I am really {emotion.lower()},] {cut_text}"

 		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
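This hunk stops the emotion cue from being applied when the emotion dropdown is "None": previously everything that was not "Custom" fell into the else branch. A standalone sketch of the resulting branching, with names mirroring the hunk:

def decorate(cut_text, emotion, prompt=None):
    # "Custom" injects the user's prompt; any other emotion except "None"
    # injects a canned first-person emotion cue.
    if emotion == "Custom":
        if prompt and prompt.strip() != "":
            cut_text = f"[{prompt},] {cut_text}"
    elif emotion != "None":
        cut_text = f"[I am really {emotion.lower()},] {cut_text}"
    return cut_text

assert decorate("Hello.", "None") == "Hello."
assert decorate("Hello.", "Angry") == "[I am really angry,] Hello."
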
@@ -465,14 +464,21 @@ def update_baseline_for_latents_chunks( voice ):
 		return 1

 	files = os.listdir(path)
+
+	total = 0
 	total_duration = 0
+
 	for file in files:
 		if file[-4:] != ".wav":
 			continue
+
 		metadata = torchaudio.info(f'{path}/{file}')
 		duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
 		total_duration += duration
+		total = total + 1

+	if args.autocalculate_voice_chunk_duration_size == 0:
+		return int(total_duration / total) if total > 0 else 1
 	return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1

 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
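With the setting at 0 (the new default in the setup_args hunk below), the baseline becomes the average clip length instead of total duration divided by a fixed chunk size. A worked example with illustrative numbers; as an aside, torchaudio.info reports num_frames independently of channel count, so the num_channels factor above overstates the duration of stereo files:

total, total_duration = 4, 100.0  # illustrative: four wavs, 100 seconds in total
size = 0                          # args.autocalculate_voice_chunk_duration_size
if size == 0:
    chunks = int(total_duration / total) if total > 0 else 1  # 25: mean clip length
else:
    chunks = int(total_duration / size) if total_duration > 0 else 1  # 10 when size=10
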
@@ -551,6 +557,8 @@ class TrainingState():
 		self.eta = "?"
 		self.eta_hhmmss = "?"

+		self.nan_detected = False
+
 		self.last_info_check_at = 0
 		self.statistics = []
 		self.losses = []
@@ -702,13 +710,10 @@ class TrainingState():
 				info_line = line.split("INFO:")[-1]
 				# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
 				if ': nan' in info_line:
-					should_return = True
-
-					print("! NAN DETECTED !")
-					self.buffer.append("! NAN DETECTED !")
+					self.nan_detected = True

 				# easily rip out our stats...
-				match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
+				match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
 				if match and len(match) > 0:
 					for k, v in match:
 						self.info[k] = float(v.replace(",", ""))
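The regex change swaps ` +?` for ` *?` after the colon, so key/value stats with no space between them are captured too. A quick check against a made-up log line in the same format:

import re

old = r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b'
new = r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b'
line = "iter: 1,250 lr: 1.0e-05 loss_gpt_total:2.5e-01"

print(re.findall(old, line))  # [('iter', '1,250'), ('lr', '1.0e-05')] - misses the spaceless pair
print(re.findall(new, line))  # also catches ('loss_gpt_total', '2.5e-01')
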
@@ -863,6 +868,8 @@ class TrainingState():
 			self.metrics['loss'] = ", ".join(self.metrics['loss'])

 			message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
+			if self.nan_detected:
+				message = f"[!NaN DETECTED!] {message}"

 		if message:
 			percent = self.it / float(self.its) # self.epoch / float(self.epochs)
@@ -966,7 +973,6 @@ def stop_training():
 	try:
 		children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
 	except Exception as e:
 		print(e)
-		pass

 	training_state.process.stdout.close()
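For reference, the list comprehension above scans every process's command line for the trainer script. A slightly more defensive version of the same psutil pattern (cmdline can be None for some processes) might look like:

import psutil

children = []
for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']):
    cmdline = p.info.get('cmdline') or []  # cmdline can be None or empty
    if './src/train.py' in cmdline:
        children.append(p.info)
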
@@ -996,58 +1002,66 @@ def convert_to_halfp():
 	torch.save(model, outfile)
 	print(f'Converted model to half precision: {outfile}')

 #
-def prepare_dataset(files, outdir, language=None, progress=None):
+def whisper_transcribe( file, language=None ):
+	# shouldn't happen, but it's for safety
+	if not whisper_model:
+		load_whisper_model(language=language)
+
+	if not args.whisper_cpp:
+		if not language:
+			language = None
+
+		return whisper_model.transcribe(file, language=language)
+
+	res = whisper_model.transcribe(file)
+	segments = whisper_model.extract_text_and_timestamps( res )
+
+	result = {
+		'segments': []
+	}
+	for segment in segments:
+		reparsed = {
+			'start': segment[0] / 100.0,
+			'end': segment[1] / 100.0,
+			'text': segment[2],
+		}
+		result['segments'].append(reparsed)
+
+	return result
+
+
+def prepare_dataset( files, outdir, language=None, progress=None ):
 	unload_tts()

 	global whisper_model

-	device = "cuda" #add cpu option?
-
-	#original whisper https://github.com/openai/whisper
-	#whisperx fork https://github.com/m-bain/whisperX
-	#supports en, fr, de, es, it, ja, zh, nl, uk, pt
-
-	#tiny, base, small, medium, large, large-v2
-	whisper_model = whisperx.load_model("medium", device)
-	#some additional model features require huggingface token
+	if whisper_model is None:
+		load_whisper_model(language=language)

 	os.makedirs(outdir, exist_ok=True)

-	idx = 0
 	results = {}
 	transcription = []

 	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-		print(f"Transcribing file: {file}")
-
-		result = whisper_model.transcribe(file)
-
-		print(result["segments"]) # before alignment
-
-		# load alignment model and metadata
-		model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-
-		# align whisper output
-		result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)
-
-		print(result_aligned["segments"]) # after alignment
-		print(result_aligned["word_segments"]) # after alignment
-
-		results[os.path.basename(file)] = result
-
+		basename = os.path.basename(file)
+		result = whisper_transcribe(file, language=language)
+		results[basename] = result
+		print(f"Transcribed file: {file}, {len(result['segments'])} found.")

 		waveform, sampling_rate = torchaudio.load(file)
 		num_channels, num_frames = waveform.shape

-		for segment in result[
-			'segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+		idx = 0
+		for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
 			start = int(segment['start'] * sampling_rate)
 			end = int(segment['end'] * sampling_rate)

 			sliced_waveform = waveform[:, start:end]
-			sliced_name = f"{pad(idx, 4)}.wav"
+			sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")

 			if not torch.any(sliced_waveform < 0):
 				print(f"Error with {sliced_name}, skipping...")
 				continue

 			torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)

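The new whisper_transcribe helper is what keeps the slicing loop backend-agnostic: whispercpp's extract_text_and_timestamps yields (start, end, text) tuples with timestamps in centiseconds, and the helper reshapes them into openai-whisper's {'segments': [...]} layout. A self-contained sketch of that reparse step, with made-up tuples:

# Illustrative whispercpp-style output: (start, end, text), times in centiseconds.
segments = [(0, 250, " Hello there."), (250, 610, " General Kenobi.")]

result = {'segments': []}
for start_cs, end_cs, text in segments:
    result['segments'].append({
        'start': start_cs / 100.0,  # centiseconds -> seconds
        'end': end_cs / 100.0,
        'text': text,
    })

# result now matches openai-whisper's shape:
# {'segments': [{'start': 0.0, 'end': 2.5, 'text': ' Hello there.'}, ...]}
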
@@ -1056,16 +1070,17 @@ def prepare_dataset(files, outdir, language=None, progress=None):
 		transcription.append(line)
+		with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
+			f.write(f'{line}\n')

-
 	with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
 		f.write(json.dumps(results, indent='\t'))

-
+	joined = '\n'.join(transcription)
 	with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write("\n".join(transcription))
+		f.write(joined)

 	unload_whisper()

-	return f"Processed dataset to: {outdir}"
+	return f"Processed dataset to: {outdir}\n{joined}"

 def calc_iterations( epochs, lines, batch_size ):
 	iterations = int(epochs * lines / float(batch_size))
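For context, calc_iterations at the tail of this hunk converts an epoch count into an iteration count from the dataset's line total. A worked example:

def calc_iterations( epochs, lines, batch_size ):
    return int(epochs * lines / float(batch_size))

print(calc_iterations(500, 200, 128))  # int(781.25) -> 781
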
@@ -1411,7 +1426,7 @@ def setup_args():
 		'prune-nonfinal-outputs': True,
 		'use-bigvgan-vocoder': True,
 		'concurrency-count': 2,
-		'autocalculate-voice-chunk-duration-size': 10,
+		'autocalculate-voice-chunk-duration-size': 0,
 		'output-sample-rate': 44100,
 		'output-volume': 1,

@@ -1750,6 +1765,34 @@ def unload_voicefixer():

 	do_gc()

+def load_whisper_model(language=None, model_name=None, progress=None):
+	global whisper_model
+
+	if not model_name:
+		model_name = args.whisper_model
+	else:
+		args.whisper_model = model_name
+		save_args_settings()
+
+	if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
+		model_name = f'{model_name}.{language}'
+		print(f"Loading specialized model for language: {language}")
+
+	notify_progress(f"Loading Whisper model: {model_name}", progress)
+
+	if args.whisper_cpp:
+		from whispercpp import Whisper
+		if not language:
+			language = 'auto'
+
+		b_lang = language.encode('ascii')
+		whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
+	else:
+		import whisper
+		whisper_model = whisper.load_model(model_name)
+
+	print("Loaded Whisper model")
+
 def unload_whisper():
 	global whisper_model

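One detail worth noting in load_whisper_model is the specialized-model lookup: when a language-specific checkpoint exists, the plain model name is upgraded to it. A sketch, assuming WHISPER_SPECIALIZED_MODELS holds the English-only variants that upstream whisper ships:

# Hypothetical contents for illustration; the real list lives elsewhere in the repo.
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]

def pick_model(model_name, language=None):
    # Mirrors the lookup above: prefer "<model>.<lang>" when such a checkpoint exists.
    if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
        return f'{model_name}.{language}'
    return model_name

print(pick_model("base", "en"))  # base.en
print(pick_model("base", "ja"))  # base (no specialized model, fall back)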