Update 'src/utils.py'

Author: yqxtqymn
Date: 2023-03-06 00:28:34 +00:00
Parent: 2101131cfb
Commit: e45ea6b26a


@@ -27,7 +27,6 @@ import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
-import whisperx
 from datetime import datetime
 from datetime import timedelta
@@ -234,7 +233,7 @@ def generate(
             if emotion == "Custom":
                 if prompt and prompt.strip() != "":
                     cut_text = f"[{prompt},] {cut_text}"
-            else:
+            elif emotion != "None":
                 cut_text = f"[I am really {emotion.lower()},] {cut_text}"
             progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
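
For context, a minimal standalone sketch of the prompt-prefix behavior after this change; the function name and sample values are hypothetical, only the branch logic comes from the diff:

def apply_emotion_prefix(cut_text, emotion, prompt=None):
    # "Custom" uses the user's prompt verbatim; any other emotion except
    # "None" injects an "[I am really <emotion>,]" conditioning prefix.
    if emotion == "Custom":
        if prompt and prompt.strip() != "":
            cut_text = f"[{prompt},] {cut_text}"
    elif emotion != "None":
        cut_text = f"[I am really {emotion.lower()},] {cut_text}"
    return cut_text

print(apply_emotion_prefix("Hello there.", "Angry"))  # [I am really angry,] Hello there.
print(apply_emotion_prefix("Hello there.", "None"))   # unchanged; the old else-branch would have prefixed this
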
@@ -465,14 +464,21 @@ def update_baseline_for_latents_chunks( voice ):
         return 1
     files = os.listdir(path)
     total = 0
+    total_duration = 0
     for file in files:
         if file[-4:] != ".wav":
             continue
+        metadata = torchaudio.info(f'{path}/{file}')
+        duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
+        total_duration += duration
         total = total + 1
+    if args.autocalculate_voice_chunk_duration_size == 0:
+        return int(total_duration / total) if total > 0 else 1
+    return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1

 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
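
A minimal sketch of the duration arithmetic the new loop relies on (the path below is hypothetical). Note that multiplying by num_channels counts a stereo file at twice its wall-clock length; the diff does not say whether that is intentional:

import torchaudio

# Mirrors the commit's formula: channels * frames / sample_rate.
metadata = torchaudio.info("voices/example/clip.wav")  # hypothetical file
duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
print(f"{duration:.2f} seconds")
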
@@ -551,6 +557,8 @@ class TrainingState():
         self.eta = "?"
         self.eta_hhmmss = "?"
+        self.nan_detected = False
+        self.last_info_check_at = 0
         self.statistics = []
         self.losses = []
@@ -702,13 +710,10 @@ class TrainingState():
             info_line = line.split("INFO:")[-1]
             # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
             if ': nan' in info_line:
-                should_return = True
                 print("! NAN DETECTED !")
                 self.buffer.append("! NAN DETECTED !")
+                self.nan_detected = True

             # easily rip out our stats...
-            match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
+            match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
             if match and len(match) > 0:
                 for k, v in match:
                     self.info[k] = float(v.replace(",", ""))
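
Besides the NaN flag, the change here loosens the stats regex from ': +?' to ': *?', making the space after the colon optional. A quick sketch of what that buys (the sample log line is made up):

import re

pattern = r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b'

# "loss_text_ce:5.0e-03" has no space after the colon; the old ': +?'
# pattern would have skipped it, the new ': *?' pattern captures it.
line = "iter: 1,000 loss_text_ce:5.0e-03 loss_mel_ce: 2.5e-02"
stats = {k: float(v.replace(",", "")) for k, v in re.findall(pattern, line)}
print(stats)  # {'iter': 1000.0, 'loss_text_ce': 0.005, 'loss_mel_ce': 0.025}
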
@@ -863,6 +868,8 @@ class TrainingState():
             self.metrics['loss'] = ", ".join(self.metrics['loss'])

             message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
+            if self.nan_detected:
+                message = f"[!NaN DETECTED!] {message}"

         if message:
             percent = self.it / float(self.its) # self.epoch / float(self.epochs)
@@ -966,7 +973,6 @@ def stop_training():
     try:
         children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
     except Exception as e:
-        print(e)
         pass

     training_state.process.stdout.close()
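
For reference, the process scan above in isolation. The None-guard is my addition, not the commit's, since cmdline can be empty for processes the caller cannot inspect:

import psutil

# cmdline is a list of argv strings, so the membership test matches an
# exact './src/train.py' argument rather than a substring.
children = [
    p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline'])
    if p.info['cmdline'] and './src/train.py' in p.info['cmdline']
]
for child in children:
    print(child['pid'], child['cmdline'])
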
@@ -996,58 +1002,66 @@ def convert_to_halfp():
     torch.save(model, outfile)
     print(f'Converted model to half precision: {outfile}')

-def prepare_dataset(files, outdir, language=None, progress=None):
+def whisper_transcribe( file, language=None ):
+    # shouldn't happen, but it's for safety
+    if not whisper_model:
+        load_whisper_model(language=language)
+
+    if not args.whisper_cpp:
+        if not language:
+            language = None
+        return whisper_model.transcribe(file, language=language)
+
+    res = whisper_model.transcribe(file)
+    segments = whisper_model.extract_text_and_timestamps( res )
+    result = {
+        'segments': []
+    }
+    for segment in segments:
+        reparsed = {
+            'start': segment[0] / 100.0,
+            'end': segment[1] / 100.0,
+            'text': segment[2],
+        }
+        result['segments'].append(reparsed)
+    return result
+
+def prepare_dataset( files, outdir, language=None, progress=None ):
     unload_tts()

     global whisper_model
-    device = "cuda" #add cpu option?
-    #original whisper https://github.com/openai/whisper
-    #whisperx fork https://github.com/m-bain/whisperX
-    #supports en, fr, de, es, it, ja, zh, nl, uk, pt
-    #tiny, base, small, medium, large, large-v2
-    whisper_model = whisperx.load_model("medium", device)
-    #some additional model features require huggingface token
+    if whisper_model is None:
+        load_whisper_model(language=language)

     os.makedirs(outdir, exist_ok=True)

-    idx = 0
     results = {}
     transcription = []

     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-        print(f"Transcribing file: {file}")
-        result = whisper_model.transcribe(file)
-        print(result["segments"]) # before alignment
-        # load alignment model and metadata
-        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-        # align whisper output
-        result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)
-        print(result_aligned["segments"]) # after alignment
-        print(result_aligned["word_segments"]) # after alignment
-        results[os.path.basename(file)] = result
+        basename = os.path.basename(file)
+        result = whisper_transcribe(file, language=language)
+        results[basename] = result
+        print(f"Transcribed file: {file}, {len(result['segments'])} found.")

         waveform, sampling_rate = torchaudio.load(file)
         num_channels, num_frames = waveform.shape

-        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+        idx = 0
+        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
             start = int(segment['start'] * sampling_rate)
             end = int(segment['end'] * sampling_rate)

             sliced_waveform = waveform[:, start:end]
-            sliced_name = f"{pad(idx, 4)}.wav"
+            sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")

             if not torch.any(sliced_waveform < 0):
                 print(f"Error with {sliced_name}, skipping...")
                 continue

             torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
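
A condensed sketch of the new pipeline, under the commit's apparent assumptions: whispercpp hands back (start, end, text) tuples with timestamps in centiseconds (hence the /100.0), and each segment is cut out of the waveform by frame index. The names below are illustrative:

import torchaudio

def slice_by_segments(file, segments, outdir, basename):
    # waveform has shape (channels, frames); seconds * sample_rate = frame index
    waveform, sampling_rate = torchaudio.load(file)
    for idx, (start_cs, end_cs, text) in enumerate(segments):
        start = int(start_cs / 100.0 * sampling_rate)  # centiseconds -> frames
        end = int(end_cs / 100.0 * sampling_rate)
        sliced_name = basename.replace(".wav", f"_{idx:04d}.wav")
        torchaudio.save(f"{outdir}/{sliced_name}", waveform[:, start:end], sampling_rate)
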
@@ -1056,16 +1070,17 @@ def prepare_dataset(files, outdir, language=None, progress=None):
         transcription.append(line)
         with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
             f.write(f'{line}\n')

     with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
         f.write(json.dumps(results, indent='\t'))

+    joined = '\n'.join(transcription)
     with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-        f.write("\n".join(transcription))
+        f.write(joined)

     unload_whisper()

-    return f"Processed dataset to: {outdir}"
+    return f"Processed dataset to: {outdir}\n{joined}"

 def calc_iterations( epochs, lines, batch_size ):
     iterations = int(epochs * lines / float(batch_size))
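
A quick sanity check of the iteration arithmetic in the context lines above (the numbers are illustrative):

# int(epochs * lines / float(batch_size)):
# e.g. 500 epochs over a 100-line dataset at batch size 25
print(int(500 * 100 / 25.0))  # 2000 iterations
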
@@ -1411,7 +1426,7 @@ def setup_args():
         'prune-nonfinal-outputs': True,
         'use-bigvgan-vocoder': True,
         'concurrency-count': 2,
-        'autocalculate-voice-chunk-duration-size': 10,
+        'autocalculate-voice-chunk-duration-size': 0,
         'output-sample-rate': 44100,
         'output-volume': 1,
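
Changing the default to 0 routes update_baseline_for_latents_chunks into its new averaging branch. A sketch of the two modes (the numbers are illustrative):

def baseline_chunks(total_duration, total, size):
    # size == 0: baseline is the voice's average clip length;
    # otherwise: total duration divided by the configured chunk size.
    if size == 0:
        return int(total_duration / total) if total > 0 else 1
    return int(total_duration / size) if total_duration > 0 else 1

print(baseline_chunks(120.0, 12, 0))   # 10 (average clip length)
print(baseline_chunks(120.0, 12, 10))  # 12 (duration / size)
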
@@ -1750,6 +1765,34 @@ def unload_voicefixer():
     do_gc()

+def load_whisper_model(language=None, model_name=None, progress=None):
+    global whisper_model
+
+    if not model_name:
+        model_name = args.whisper_model
+    else:
+        args.whisper_model = model_name
+        save_args_settings()
+
+    if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
+        model_name = f'{model_name}.{language}'
+        print(f"Loading specialized model for language: {language}")
+
+    notify_progress(f"Loading Whisper model: {model_name}", progress)
+
+    if args.whisper_cpp:
+        from whispercpp import Whisper
+        if not language:
+            language = 'auto'
+        b_lang = language.encode('ascii')
+        whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
+    else:
+        import whisper
+        whisper_model = whisper.load_model(model_name)
+
+    print("Loaded Whisper model")
+
 def unload_whisper():
     global whisper_model
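
A sketch of the specialized-model lookup in load_whisper_model. The list contents below are an assumption based on openai/whisper's English-only checkpoints; the diff only shows that WHISPER_SPECIALIZED_MODELS exists:

# Hypothetical contents; openai/whisper ships tiny.en/base.en/small.en/medium.en.
WHISPER_SPECIALIZED_MODELS = ['tiny.en', 'base.en', 'small.en', 'medium.en']

def resolve_model_name(model_name, language=None):
    if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
        return f'{model_name}.{language}'
    return model_name

print(resolve_model_name('base', 'en'))  # 'base.en'
print(resolve_model_name('base', 'ja'))  # 'base' (no specialized build)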