Update 'src/utils.py'

yqxtqymn 2023-03-06 00:28:34 +00:00
parent 2101131cfb
commit e45ea6b26a

diff --git a/src/utils.py b/src/utils.py
--- a/src/utils.py
+++ b/src/utils.py

@@ -27,7 +27,6 @@ import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
-import whisperx
 from datetime import datetime
 from datetime import timedelta
@@ -234,7 +233,7 @@ def generate(
 			if emotion == "Custom":
 				if prompt and prompt.strip() != "":
 					cut_text = f"[{prompt},] {cut_text}"
-			else:
+			elif emotion != "None":
 				cut_text = f"[I am really {emotion.lower()},] {cut_text}"

 			progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
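
Note on the hunk above: "None" previously fell into the else branch and got "[I am really none,]" prepended to the prompt. A minimal sketch of the new behavior (text and emotion values are illustrative):

	cut_text = "The quick brown fox."
	emotion = "Angry"
	if emotion == "Custom":
		pass  # uses the user-supplied prompt instead
	elif emotion != "None":
		cut_text = f"[I am really {emotion.lower()},] {cut_text}"
	# "Angry" -> "[I am really angry,] The quick brown fox."
	# "None"  -> text left untouched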
@@ -465,14 +464,21 @@ def update_baseline_for_latents_chunks( voice ):
 		return 1

 	files = os.listdir(path)

+	total = 0
 	total_duration = 0
 	for file in files:
 		if file[-4:] != ".wav":
 			continue
 		metadata = torchaudio.info(f'{path}/{file}')
 		duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
 		total_duration += duration
+		total = total + 1
+
+	if args.autocalculate_voice_chunk_duration_size == 0:
+		return int(total_duration / total) if total > 0 else 1

 	return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1

 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
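
Note: with the new default of 0 for autocalculate-voice-chunk-duration-size (see the setup_args() hunk below), the early return makes the chunk count the average clip length in seconds. A worked example of the arithmetic, with made-up durations (note also that the computed duration scales with channel count, since num_channels is multiplied in):

	durations = [12.0, 8.0, 10.0]         # three .wav files, in seconds
	total, total_duration = len(durations), sum(durations)
	chunks = int(total_duration / total)  # int(30.0 / 3) == 10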
@@ -551,6 +557,8 @@ class TrainingState():
 		self.eta = "?"
 		self.eta_hhmmss = "?"

+		self.nan_detected = False
+
 		self.last_info_check_at = 0
 		self.statistics = []
 		self.losses = []
@@ -702,13 +710,10 @@ class TrainingState():
 			info_line = line.split("INFO:")[-1]

 			# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
 			if ': nan' in info_line:
-				should_return = True
-				print("! NAN DETECTED !")
-				self.buffer.append("! NAN DETECTED !")
+				self.nan_detected = True

 			# easily rip out our stats...
-			match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
+			match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
 			if match and len(match) > 0:
 				for k, v in match:
 					self.info[k] = float(v.replace(",", ""))
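
Note: the only regex change is ' +?' -> ' *?' after the colon, so stats printed with no space after the colon are now captured too. A quick check against an illustrative log line (not from a real run):

	import re

	line = "loss_text_ce: 4.9109e+00, loss_mel_ce:2.8063e+00, iter: 1,250"
	match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
	# [('loss_text_ce', '4.9109e+00'), ('loss_mel_ce', '2.8063e+00'), ('iter', '1,250')]
	info = {k: float(v.replace(",", "")) for k, v in match}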
@@ -863,6 +868,8 @@ class TrainingState():
 			self.metrics['loss'] = ", ".join(self.metrics['loss'])

 			message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
+			if self.nan_detected:
+				message = f"[!NaN DETECTED!] {message}"

 		if message:
 			percent = self.it / float(self.its) # self.epoch / float(self.epochs)
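
Note: together with the two hunks above, this replaces the old print/buffer spam with a persistent flag: ': nan' in any parsed INFO line sets self.nan_detected once, and every status message after that is prefixed. A condensed sketch with an illustrative log line and message:

	info_line = "loss_text_ce: nan, loss_mel_ce: nan"
	nan_detected = ': nan' in info_line   # True
	message = "[100/500] [0.25it/s] [ETA: 00:10:00]"
	if nan_detected:
		message = f"[!NaN DETECTED!] {message}"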
@@ -966,7 +973,6 @@ def stop_training():
 	try:
 		children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
 	except Exception as e:
-		print(e)
 		pass

 	training_state.process.stdout.close()
@@ -996,58 +1002,66 @@ def convert_to_halfp():
 	torch.save(model, outfile)
 	print(f'Converted model to half precision: {outfile}')

+def whisper_transcribe( file, language=None ):
+	# shouldn't happen, but it's for safety
+	if not whisper_model:
+		load_whisper_model(language=language)
+
+	if not args.whisper_cpp:
+		if not language:
+			language = None
+
+		return whisper_model.transcribe(file, language=language)
+
+	res = whisper_model.transcribe(file)
+	segments = whisper_model.extract_text_and_timestamps( res )
+
+	result = {
+		'segments': []
+	}
+	for segment in segments:
+		reparsed = {
+			'start': segment[0] / 100.0,
+			'end': segment[1] / 100.0,
+			'text': segment[2],
+		}
+		result['segments'].append(reparsed)
+
+	return result
+
-def prepare_dataset(files, outdir, language=None, progress=None):
+def prepare_dataset( files, outdir, language=None, progress=None ):
 	unload_tts()

 	global whisper_model
-	device = "cuda" #add cpu option?
-	#original whisper https://github.com/openai/whisper
-	#whisperx fork https://github.com/m-bain/whisperX
-	#supports en, fr, de, es, it, ja, zh, nl, uk, pt
-	#tiny, base, small, medium, large, large-v2
-	whisper_model = whisperx.load_model("medium", device)
-	#some additional model features require huggingface token
+	if whisper_model is None:
+		load_whisper_model(language=language)

 	os.makedirs(outdir, exist_ok=True)

-	idx = 0
 	results = {}
 	transcription = []

 	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-		print(f"Transcribing file: {file}")
-		result = whisper_model.transcribe(file)
-		print(result["segments"]) # before alignment
-		# load alignment model and metadata
-		model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-		# align whisper output
-		result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)
-		print(result_aligned["segments"]) # after alignment
-		print(result_aligned["word_segments"]) # after alignment
-		results[os.path.basename(file)] = result
+		basename = os.path.basename(file)
+		result = whisper_transcribe(file, language=language)
+		results[basename] = result
 		print(f"Transcribed file: {file}, {len(result['segments'])} found.")

 		waveform, sampling_rate = torchaudio.load(file)
 		num_channels, num_frames = waveform.shape

-		for segment in result[
-			'segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+		idx = 0
+		for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
 			start = int(segment['start'] * sampling_rate)
 			end = int(segment['end'] * sampling_rate)

 			sliced_waveform = waveform[:, start:end]
-			sliced_name = f"{pad(idx, 4)}.wav"
+			sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
+
+			if not torch.any(sliced_waveform < 0):
+				print(f"Error with {sliced_name}, skipping...")
+				continue

 			torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
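
Note: whisper_transcribe() normalizes whisper.cpp output into the openai-whisper result shape, so the slicing loop stays backend-agnostic. A minimal sketch of that conversion; the (start, end, text) tuples and the 10 ms timestamp units are assumptions read off the /100.0 in the diff, not a documented whispercpp contract:

	segments = [(0, 250, " Hello there."), (250, 610, " General Kenobi.")]
	result = {'segments': [
		{'start': s / 100.0, 'end': e / 100.0, 'text': t}
		for s, e, t in segments
	]}
	# result['segments'][0] == {'start': 0.0, 'end': 2.5, 'text': ' Hello there.'}

The sliced_name change also makes slice names carry the source file's basename: "voice.wav" now yields "voice_0000.wav", "voice_0001.wav", ... instead of bare counter names like "0000.wav".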
@@ -1056,16 +1070,17 @@ def prepare_dataset(files, outdir, language=None, progress=None):
 			transcription.append(line)
 			with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
 				f.write(f'{line}\n')

 	with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
 		f.write(json.dumps(results, indent='\t'))

+	joined = '\n'.join(transcription)
 	with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write("\n".join(transcription))
+		f.write(joined)

 	unload_whisper()

-	return f"Processed dataset to: {outdir}"
+	return f"Processed dataset to: {outdir}\n{joined}"

 def calc_iterations( epochs, lines, batch_size ):
 	iterations = int(epochs * lines / float(batch_size))
@@ -1411,7 +1426,7 @@ def setup_args():
 		'prune-nonfinal-outputs': True,
 		'use-bigvgan-vocoder': True,
 		'concurrency-count': 2,
-		'autocalculate-voice-chunk-duration-size': 10,
+		'autocalculate-voice-chunk-duration-size': 0,
 		'output-sample-rate': 44100,
 		'output-volume': 1,
@@ -1750,6 +1765,34 @@ def unload_voicefixer():
 	do_gc()

+def load_whisper_model(language=None, model_name=None, progress=None):
+	global whisper_model
+
+	if not model_name:
+		model_name = args.whisper_model
+	else:
+		args.whisper_model = model_name
+		save_args_settings()
+
+	if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
+		model_name = f'{model_name}.{language}'
+		print(f"Loading specialized model for language: {language}")
+
+	notify_progress(f"Loading Whisper model: {model_name}", progress)
+
+	if args.whisper_cpp:
+		from whispercpp import Whisper
+
+		if not language:
+			language = 'auto'
+
+		b_lang = language.encode('ascii')
+		whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
+	else:
+		import whisper
+		whisper_model = whisper.load_model(model_name)
+
+	print("Loaded Whisper model")
+
 def unload_whisper():
 	global whisper_model
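
Note: a hypothetical end-to-end use of the new helpers from the hunks above; the model name and file path are placeholders, and whether the specialized 'base.en' variant gets picked depends on WHISPER_SPECIALIZED_MODELS:

	load_whisper_model(language='en', model_name='base')  # may specialize to 'base.en'
	result = whisper_transcribe('./voices/mine/sample.wav', language='en')
	print(f"{len(result['segments'])} segments")
	unload_whisper()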