diff --git a/src/utils.py b/src/utils.py
index a90a7b0..c5aac42 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -34,8 +34,6 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name
 
-import whisper
-
 MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
 
 args = None
@@ -46,7 +44,6 @@ voicefixer = None
 whisper_model = None
 training_state = None
 
-
 def generate(
     text,
     delimiter,
@@ -501,9 +498,12 @@ class TrainingState():
             match = re.findall(r'iter: ([\d,]+)', line)
             if match and len(match) > 0:
                 self.it = int(match[0].replace(",", ""))
+
+                self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
         else:
             lapsed = False
+            message = None
 
             if line.find('%|') > 0:
                 match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
                 if match and len(match) > 0:
@@ -516,8 +516,6 @@ class TrainingState():
                     until = match[5]
                     rate = match[6]
 
-                    epoch_percent = self.it / float(self.its) # self.epoch / float(self.epochs)
-
                     last_step = self.last_step
                     self.last_step = step
                     if last_step < step:
@@ -530,10 +528,12 @@ class TrainingState():
                         self.it_time_delta = self.it_time_end-self.it_time_start
                         self.it_time_start = time.time()
                         try:
-                            rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]'
+                            rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
                             self.it_rate = rate
                         except Exception as e:
                             pass
+
+                        message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [ETA: {self.eta_hhmmss}] [{self.epoch_rate}, {self.it_rate}] {self.status}'
 
                         """
                         # I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly
@@ -550,13 +550,6 @@ class TrainingState():
                             pass
                         """
 
-                        message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
-                        if progress is not None:
-                            progress(epoch_percent, message)
-
-                        # print(f'{"{:.3f}".format(percent*100)}% {message}')
-                        self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}')
-
             if lapsed:
                 self.epoch = self.epoch + 1
                 self.it = int(self.epoch * (self.dataset_size / self.batch_size))
@@ -564,7 +557,7 @@ class TrainingState():
                 self.epoch_time_end = time.time()
                 self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
                 self.epoch_time_start = time.time()
-                self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here
+                self.epoch_rate = f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' if self.epoch_time_delta >= 1 else f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' # I doubt anyone will have it/s rates, but its here
                 #self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
 
                 self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta
@@ -576,14 +569,12 @@ class TrainingState():
                 except Exception as e:
                     pass
 
-                percent = self.epoch / float(self.epochs)
-                message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
-
+            if message:
+                percent = self.it / float(self.its) # self.epoch / float(self.epochs)
                 if progress is not None:
                     progress(percent, message)
 
-                print(f'{"{:.3f}".format(percent*100)}% {message}')
-                self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
+                self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
 
             if line.find('INFO: [epoch:') >= 0:
                 # easily rip out our stats...
@@ -677,12 +668,36 @@ def convert_to_halfp():
     torch.save(model, outfile)
     print(f'Converted model to half precision: {outfile}')
 
+def whisper_transcribe( file, language=None ):
+    # shouldn't happen, but it's for safety
+    if not whisper_model:
+        load_whisper_model(language=language if language else b'en')
+
+    if not args.whisper_cpp:
+        return whisper_model.transcribe(file, language=language if language else "English")
+
+    res = whisper_model.transcribe(file)
+    segments = whisper_model.extract_text_and_timestamps( res )
+
+    result = {
+        'segments': []
+    }
+    for segment in segments:
+        reparsed = {
+            'start': segment[0],
+            'end': segment[1],
+            'text': segment[2],
+        }
+        result['segments'].append(reparsed)
+    return result
+
+
 def prepare_dataset( files, outdir, language=None, progress=None ):
     unload_tts()
 
     global whisper_model
     if whisper_model is None:
-        load_whisper_model()
+        load_whisper_model(language=language)
 
     os.makedirs(outdir, exist_ok=True)
 
@@ -693,7 +708,7 @@ def prepare_dataset( files, outdir, language=None, progress=None ):
 
     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
         print(f"Transcribing file: {file}")
-        result = whisper_model.transcribe(file, language=language if language else "English")
+        result = whisper_transcribe(file, language=language) # whisper_model.transcribe(file, language=language if language else "English")
         results[os.path.basename(file)] = result
         print(f"Transcribed file: {file}, {len(result['segments'])} found.")
 
@@ -1037,11 +1052,13 @@ def setup_args():
         'defer-tts-load': False,
         'device-override': None,
         'prune-nonfinal-outputs': True,
-        'whisper-model': "base",
-        'autoregressive-model': None,
         'concurrency-count': 2,
         'output-sample-rate': 44100,
         'output-volume': 1,
+
+        'autoregressive-model': None,
+        'whisper-model': "base",
+        'whisper-cpp': False,
 
         'training-default-halfp': False,
         'training-default-bnb': True,
@@ -1067,13 +1084,15 @@ def setup_args():
     parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
     parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
     parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
-    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
-    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
     parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
     parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
     parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
     parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
 
+    parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
+    parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
+    parser.add_argument("--whisper-cpp", default=default_arguments['whisper-cpp'], action='store_true', help="Leverages lightmare/whispercpp for transcription")
+
     parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
     parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
 
@@ -1103,7 +1122,7 @@ def setup_args():
 
     return args
 
-def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, training_default_halfp, training_default_bnb ):
+def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
     global args
 
     args.listen = listen
@@ -1123,6 +1142,11 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
     args.concurrency_count = concurrency_count
     args.output_sample_rate = output_sample_rate
     args.output_volume = output_volume
+
+    args.autoregressive_model = autoregressive_model
+    args.whisper_model = whisper_model
+    args.whisper_cpp = whisper_cpp
+
     args.training_default_halfp = training_default_halfp
     args.training_default_bnb = training_default_bnb
 
@@ -1140,8 +1164,6 @@ def save_args_settings():
         'defer-tts-load': args.defer_tts_load,
         'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
         'device-override': args.device_override,
-        'whisper-model': args.whisper_model,
-        'autoregressive-model': args.autoregressive_model,
         'sample-batch-size': args.sample_batch_size,
         'embed-output-metadata': args.embed_output_metadata,
         'latents-lean-and-mean': args.latents_lean_and_mean,
@@ -1150,6 +1172,10 @@ def save_args_settings():
         'concurrency-count': args.concurrency_count,
         'output-sample-rate': args.output_sample_rate,
         'output-volume': args.output_volume,
+
+        'autoregressive-model': args.autoregressive_model,
+        'whisper-model': args.whisper_model,
+        'whisper-cpp': args.whisper_cpp,
 
         'training-default-halfp': args.training_default_halfp,
         'training-default-bnb': args.training_default_bnb,
@@ -1292,9 +1318,7 @@ def update_autoregressive_model(autoregressive_model_path):
     if not tts:
         if tts_loading:
             raise Exception("TTS is still initializing...")
-
-        load_tts( model=autoregressive_model_path )
-        return # redundant to proceed onward
+        return
 
     print(f"Loading model: {autoregressive_model_path}")
 
@@ -1348,7 +1372,7 @@ def unload_voicefixer():
 
     do_gc()
 
-def load_whisper_model(name=None, progress=None):
+def load_whisper_model(name=None, progress=None, language=b'en'):
     global whisper_model
 
     if not name:
@@ -1358,7 +1382,12 @@ def load_whisper_model(name=None, progress=None):
 
         save_args_settings()
 
    notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
-    whisper_model = whisper.load_model(args.whisper_model)
+    if args.whisper_cpp:
+        from whispercpp import Whisper
+        whisper_model = Whisper(name, models_dir='./models/', language=language)
+    else:
+        import whisper
+        whisper_model = whisper.load_model(args.whisper_model)
     print("Loaded Whisper model")
 
@@ -1372,10 +1401,13 @@ def unload_whisper():
 
     do_gc()
 
+"""
 def update_whisper_model(name, progress=None):
     if not name:
         return
 
+    args.whisper_model = name
+    save_args_settings()
 
     global whisper_model
     if whisper_model:
@@ -1383,4 +1415,5 @@ def update_whisper_model(name, progress=None):
         unload_whisper()
         load_whisper_model(name)
     else:
         args.whisper_model = name
-        save_args_settings()
\ No newline at end of file
+        save_args_settings()
+"""
\ No newline at end of file
diff --git a/src/webui.py b/src/webui.py
index 1ebcc43..f479563 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -537,7 +537,12 @@ def setup_gradio():
 
                     autoregressive_models = get_autoregressive_models()
                     autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
+                    whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
+                    use_whisper_cpp = gr.Checkbox(label="Use Whisper.cpp", value=args.whisper_cpp)
+
+                    exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_model_dropdown, use_whisper_cpp, training_halfp, training_bnb ]
+
                     with gr.Row():
                         autoregressive_models_update_button = gr.Button(value="Refresh Model List")
                         gr.Button(value="Check for Updates").click(check_for_updates)
 
@@ -559,22 +564,21 @@ def setup_gradio():
 
                         outputs=autoregressive_model_dropdown,
                     )
 
-                    autoregressive_model_dropdown.change(
-                        fn=update_autoregressive_model,
-                        inputs=autoregressive_model_dropdown,
-                        outputs=None
-                    )
-                    whisper_model_dropdown.change(
-                        fn=update_whisper_model,
-                        inputs=whisper_model_dropdown,
-                        outputs=None
-                    )
-
-                    exec_inputs = exec_inputs + [ training_halfp, training_bnb ]
-
-
                     for i in exec_inputs: i.change( fn=update_args, inputs=exec_inputs )
+
+                    autoregressive_model_dropdown.change(
+                        fn=update_autoregressive_model,
+                        inputs=autoregressive_model_dropdown,
+                        outputs=None
+                    )
+                    """
+                    whisper_model_dropdown.change(
+                        fn=update_whisper_model,
+                        inputs=whisper_model_dropdown,
+                        outputs=None
+                    )
+                    """
 
 
         # console_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
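
A short usage sketch for reviewers (not part of the patch; the voice path below is illustrative only): with either backend selected, the whisper_transcribe() wrapper added above is expected to hand back the same normalized layout, so callers such as prepare_dataset() stay backend-agnostic.

    # Hypothetical sketch, assuming the patched src/utils.py is importable.
    result = whisper_transcribe("./voices/example/clip.wav")  # illustrative path
    for segment in result['segments']:
        # both the openai/whisper path and the lightmare/whispercpp path
        # yield segments carrying 'start', 'end', and 'text'
        print(segment['start'], segment['end'], segment['text'])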