diff --git a/src/utils.py b/src/utils.py
index 1bcb74e..41e8d16 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -27,7 +27,6 @@
 import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
-import whisperx
 from datetime import datetime
 from datetime import timedelta
@@ -234,7 +233,7 @@ def generate(
 		if emotion == "Custom":
 			if prompt and prompt.strip() != "":
 				cut_text = f"[{prompt},] {cut_text}"
-		else:
+		elif emotion != "None":
 			cut_text = f"[I am really {emotion.lower()},] {cut_text}"

 		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
@@ -465,14 +464,21 @@ def update_baseline_for_latents_chunks( voice ):
 		return 1

 	files = os.listdir(path)
+
+	total = 0
 	total_duration = 0
+
 	for file in files:
 		if file[-4:] != ".wav":
			continue
+
 		metadata = torchaudio.info(f'{path}/{file}')
 		duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
 		total_duration += duration
+		total = total + 1

+	if args.autocalculate_voice_chunk_duration_size == 0:
+		return int(total_duration / total) if total > 0 else 1
 	return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1

 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
@@ -551,6 +557,8 @@ class TrainingState():
 		self.eta = "?"
 		self.eta_hhmmss = "?"

+		self.nan_detected = False
+
 		self.last_info_check_at = 0
 		self.statistics = []
 		self.losses = []
@@ -702,13 +710,10 @@ class TrainingState():
 			info_line = line.split("INFO:")[-1]
 			# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
 			if ': nan' in info_line:
-				should_return = True
-
-				print("! NAN DETECTED !")
-				self.buffer.append("! NAN DETECTED !")
+				self.nan_detected = True

 			# easily rip out our stats...
-			match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
+			match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
 			if match and len(match) > 0:
 				for k, v in match:
 					self.info[k] = float(v.replace(",", ""))
@@ -863,6 +868,8 @@ class TrainingState():
 			self.metrics['loss'] = ", ".join(self.metrics['loss'])

 			message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
+			if self.nan_detected:
+				message = f"[!NaN DETECTED!] {message}"

 		if message:
 			percent = self.it / float(self.its) # self.epoch / float(self.epochs)
@@ -966,7 +973,6 @@ def stop_training():
 	try:
 		children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
 	except Exception as e:
-		print(e)
 		pass

 	training_state.process.stdout.close()
@@ -996,58 +1002,66 @@ def convert_to_halfp():
 	torch.save(model, outfile)
 	print(f'Converted model to half precision: {outfile}')

-#
-def prepare_dataset(files, outdir, language=None, progress=None):
+def whisper_transcribe( file, language=None ):
+	# shouldn't happen, but it's for safety
+	if not whisper_model:
+		load_whisper_model(language=language)
+
+	if not args.whisper_cpp:
+		if not language:
+			language = None
+
+		return whisper_model.transcribe(file, language=language)
+
+	res = whisper_model.transcribe(file)
+	segments = whisper_model.extract_text_and_timestamps( res )
+
+	result = {
+		'segments': []
+	}
+	for segment in segments:
+		reparsed = {
+			'start': segment[0] / 100.0,
+			'end': segment[1] / 100.0,
+			'text': segment[2],
+		}
+		result['segments'].append(reparsed)
+
+	return result
+
+def prepare_dataset( files, outdir, language=None, progress=None ):
 	unload_tts()

 	global whisper_model
-
-	device = "cuda" #add cpu option?
-
-	#original whisper https://github.com/openai/whisper
-	#whisperx fork https://github.com/m-bain/whisperX
-	#supports en, fr, de, es, it, ja, zh, nl, uk, pt
-
-	#tiny, base, small, medium, large, large-v2
-	whisper_model = whisperx.load_model("medium", device)
-	#some additional model features require huggingface token
+	if whisper_model is None:
+		load_whisper_model(language=language)

 	os.makedirs(outdir, exist_ok=True)

-	idx = 0
 	results = {}
 	transcription = []

 	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-		print(f"Transcribing file: {file}")
-
-		result = whisper_model.transcribe(file)
-
-		print(result["segments"]) # before alignment
-
-		# load alignment model and metadata
-		model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-
-		# align whisper output
-		result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)
-
-		print(result_aligned["segments"]) # after alignment
-		print(result_aligned["word_segments"]) # after alignment
-
-		results[os.path.basename(file)] = result
-
+		basename = os.path.basename(file)
+		result = whisper_transcribe(file, language=language)
+		results[basename] = result
 		print(f"Transcribed file: {file}, {len(result['segments'])} found.")

 		waveform, sampling_rate = torchaudio.load(file)
 		num_channels, num_frames = waveform.shape

-		for segment in result[
-			'segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+		idx = 0
+		for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
 			start = int(segment['start'] * sampling_rate)
 			end = int(segment['end'] * sampling_rate)

 			sliced_waveform = waveform[:, start:end]
-			sliced_name = f"{pad(idx, 4)}.wav"
+			sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
+
+			if not torch.any(sliced_waveform < 0):
+				print(f"Error with {sliced_name}, skipping...")
+				continue

 			torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)

@@ -1056,16 +1070,17 @@ def prepare_dataset(files, outdir, language=None, progress=None):
 			transcription.append(line)
 			with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
 				f.write(f'{line}\n')
-	
+
 	with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
 		f.write(json.dumps(results, indent='\t'))
-	
+
+	joined = '\n'.join(transcription)
 	with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write("\n".join(transcription))
+		f.write(joined)

 	unload_whisper()

-	return f"Processed dataset to: {outdir}"
+	return f"Processed dataset to: {outdir}\n{joined}"

 def calc_iterations( epochs, lines, batch_size ):
 	iterations = int(epochs * lines / float(batch_size))
@@ -1411,7 +1426,7 @@ def setup_args():
 		'prune-nonfinal-outputs': True,
 		'use-bigvgan-vocoder': True,
 		'concurrency-count': 2,
-		'autocalculate-voice-chunk-duration-size': 10,
+		'autocalculate-voice-chunk-duration-size': 0,
 		'output-sample-rate': 44100,
 		'output-volume': 1,

@@ -1750,6 +1765,34 @@ def unload_voicefixer():
 	do_gc()

+def load_whisper_model(language=None, model_name=None, progress=None):
+	global whisper_model
+
+	if not model_name:
+		model_name = args.whisper_model
+	else:
+		args.whisper_model = model_name
+		save_args_settings()
+
+	if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
+		model_name = f'{model_name}.{language}'
+		print(f"Loading specialized model for language: {language}")
+
+	notify_progress(f"Loading Whisper model: {model_name}", progress)
+
+	if args.whisper_cpp:
+		from whispercpp import Whisper
+		if not language:
+			language = 'auto'
+
+		b_lang = language.encode('ascii')
+		whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
+	else:
+		import whisper
+		whisper_model = whisper.load_model(model_name)
+
+	print("Loaded Whisper model")
+
 def unload_whisper():
 	global whisper_model
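
Note on the whisper.cpp path: whisper_transcribe() divides the start/end values returned by extract_text_and_timestamps() by 100.0, treating them as centisecond offsets and normalizing them to seconds, so prepare_dataset() can slice by sample index the same way it does for openai/whisper output. A minimal sketch of that arithmetic, using a hypothetical segment tuple and sample rate (not values taken from the repository):

    # hypothetical whisper.cpp-style segment: (start_cs, end_cs, text)
    segment = (150, 420, "hello there")
    sampling_rate = 22050

    start_seconds = segment[0] / 100.0   # 1.5 s
    end_seconds = segment[1] / 100.0     # 4.2 s

    # same conversion prepare_dataset() applies before slicing the waveform
    start_sample = int(start_seconds * sampling_rate)   # 33075
    end_sample = int(end_seconds * sampling_rate)        # 92610
    # sliced_waveform = waveform[:, start_sample:end_sample]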
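Note on the changed default: with 'autocalculate-voice-chunk-duration-size' now defaulting to 0, update_baseline_for_latents_chunks() returns the average clip duration (total_duration / total) rather than the total duration divided by a fixed 10-second window. A small worked example with hypothetical numbers, where duration_size stands in for args.autocalculate_voice_chunk_duration_size:

    # four clips totalling 48 seconds of audio
    total = 4
    total_duration = 48.0
    duration_size = 0   # new default

    if duration_size == 0:
        chunks = int(total_duration / total) if total > 0 else 1   # 12, the average clip length
    else:
        chunks = int(total_duration / duration_size) if total_duration > 0 else 1

    print(chunks)   # 12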