Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
7c9f55b1de |
66
src/utils.py
66
src/utils.py
|
@ -42,8 +42,6 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
|
||||||
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||||
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
|
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
|
||||||
|
|
||||||
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band'] #, 'bigvgan_24khz_100band']
|
|
||||||
|
|
||||||
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
|
@ -985,9 +983,6 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
|
||||||
global training_state
|
global training_state
|
||||||
if training_state and training_state.process:
|
if training_state and training_state.process:
|
||||||
return "Training already in progress"
|
return "Training already in progress"
|
||||||
|
|
||||||
# ensure we have the dvae.pth
|
|
||||||
get_model_path('dvae.pth')
|
|
||||||
|
|
||||||
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
|
# I don't know if this is still necessary, as it was bitching at me for not doing this, despite it being in a separate process
|
||||||
torch.multiprocessing.freeze_support()
|
torch.multiprocessing.freeze_support()
|
||||||
|
@ -1198,6 +1193,10 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
||||||
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(results, indent='\t'))
|
f.write(json.dumps(results, indent='\t'))
|
||||||
|
|
||||||
|
joined = '\n'.join(transcription)
|
||||||
|
with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(joined)
|
||||||
|
|
||||||
unload_whisper()
|
unload_whisper()
|
||||||
|
|
||||||
return f"Processed dataset to: {outdir}\n{joined}"
|
return f"Processed dataset to: {outdir}\n{joined}"
|
||||||
|
@ -1544,7 +1543,7 @@ def setup_args():
|
||||||
'defer-tts-load': False,
|
'defer-tts-load': False,
|
||||||
'device-override': None,
|
'device-override': None,
|
||||||
'prune-nonfinal-outputs': True,
|
'prune-nonfinal-outputs': True,
|
||||||
'vocoder-model': VOCODERS[-1],
|
'use-bigvgan-vocoder': True,
|
||||||
'concurrency-count': 2,
|
'concurrency-count': 2,
|
||||||
'autocalculate-voice-chunk-duration-size': 0,
|
'autocalculate-voice-chunk-duration-size': 0,
|
||||||
'output-sample-rate': 44100,
|
'output-sample-rate': 44100,
|
||||||
|
@ -1581,7 +1580,7 @@ def setup_args():
|
||||||
parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
|
parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
|
||||||
parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
|
parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
|
||||||
parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
|
parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
|
||||||
parser.add_argument("--vocoder-model", default=default_arguments['vocoder-model'], action='store_true', help="Specifies with vocoder to use")
|
parser.add_argument("--use-bigvgan-vocoder", default=default_arguments['use-bigvgan-vocoder'], action='store_true', help="Uses BigVGAN in place of the default vocoder")
|
||||||
parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
|
parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
|
||||||
parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
|
parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
|
||||||
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
|
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
|
||||||
|
@ -1625,7 +1624,7 @@ def setup_args():
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, vocoder_model, whisper_backend, whisper_model, training_default_halfp, training_default_bnb ):
|
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, whisper_backend, whisper_model, training_default_halfp, training_default_bnb ):
|
||||||
global args
|
global args
|
||||||
|
|
||||||
args.listen = listen
|
args.listen = listen
|
||||||
|
@ -1636,6 +1635,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
|
||||||
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
|
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
|
||||||
args.defer_tts_load = defer_tts_load
|
args.defer_tts_load = defer_tts_load
|
||||||
args.prune_nonfinal_outputs = prune_nonfinal_outputs
|
args.prune_nonfinal_outputs = prune_nonfinal_outputs
|
||||||
|
args.use_bigvgan_vocoder = use_bigvgan_vocoder
|
||||||
args.device_override = device_override
|
args.device_override = device_override
|
||||||
args.sample_batch_size = sample_batch_size
|
args.sample_batch_size = sample_batch_size
|
||||||
args.embed_output_metadata = embed_output_metadata
|
args.embed_output_metadata = embed_output_metadata
|
||||||
|
@ -1648,7 +1648,6 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
|
||||||
args.output_volume = output_volume
|
args.output_volume = output_volume
|
||||||
|
|
||||||
args.autoregressive_model = autoregressive_model
|
args.autoregressive_model = autoregressive_model
|
||||||
args.vocoder_model = vocoder_model
|
|
||||||
args.whisper_backend = whisper_backend
|
args.whisper_backend = whisper_backend
|
||||||
args.whisper_model = whisper_model
|
args.whisper_model = whisper_model
|
||||||
|
|
||||||
|
@ -1668,6 +1667,7 @@ def save_args_settings():
|
||||||
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
|
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
|
||||||
'defer-tts-load': args.defer_tts_load,
|
'defer-tts-load': args.defer_tts_load,
|
||||||
'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
|
'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
|
||||||
|
'use-bigvgan-vocoder': args.use_bigvgan_vocoder,
|
||||||
'device-override': args.device_override,
|
'device-override': args.device_override,
|
||||||
'sample-batch-size': args.sample_batch_size,
|
'sample-batch-size': args.sample_batch_size,
|
||||||
'embed-output-metadata': args.embed_output_metadata,
|
'embed-output-metadata': args.embed_output_metadata,
|
||||||
|
@ -1680,7 +1680,6 @@ def save_args_settings():
|
||||||
'output-volume': args.output_volume,
|
'output-volume': args.output_volume,
|
||||||
|
|
||||||
'autoregressive-model': args.autoregressive_model,
|
'autoregressive-model': args.autoregressive_model,
|
||||||
'vocoder-model': args.vocoder_model,
|
|
||||||
'whisper-backend': args.whisper_backend,
|
'whisper-backend': args.whisper_backend,
|
||||||
'whisper-model': args.whisper_model,
|
'whisper-model': args.whisper_model,
|
||||||
|
|
||||||
|
@ -1796,11 +1795,11 @@ def load_tts( restart=False, model=None ):
|
||||||
if model:
|
if model:
|
||||||
args.autoregressive_model = model
|
args.autoregressive_model = model
|
||||||
|
|
||||||
print(f"Loading TorToiSe... (AR: {args.autoregressive_model}, vocoder: {args.vocoder_model})")
|
print(f"Loading TorToiSe... (using model: {args.autoregressive_model})")
|
||||||
|
|
||||||
tts_loading = True
|
tts_loading = True
|
||||||
try:
|
try:
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, vocoder_model=args.vocoder_model)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||||
load_autoregressive_model(args.autoregressive_model)
|
load_autoregressive_model(args.autoregressive_model)
|
||||||
|
@ -1848,32 +1847,35 @@ def update_autoregressive_model(autoregressive_model_path):
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"Loading model: {autoregressive_model_path}")
|
print(f"Loading model: {autoregressive_model_path}")
|
||||||
tts.load_autoregressive_model(autoregressive_model_path)
|
|
||||||
|
if hasattr(tts, 'load_autoregressive_model') and tts.load_autoregressive_model(autoregressive_model_path):
|
||||||
|
tts.load_autoregressive_model(autoregressive_model_path)
|
||||||
|
# polyfill in case a user did NOT update the packages
|
||||||
|
# this shouldn't happen anymore, as I just clone mrq/tortoise-tts, and inject it into sys.path
|
||||||
|
else:
|
||||||
|
from tortoise.models.autoregressive import UnifiedVoice
|
||||||
|
|
||||||
|
tts.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', tts.models_dir)
|
||||||
|
|
||||||
|
del tts.autoregressive
|
||||||
|
tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
||||||
|
model_dim=1024,
|
||||||
|
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
||||||
|
train_solo_embeddings=False).cpu().eval()
|
||||||
|
tts.autoregressive.load_state_dict(torch.load(tts.autoregressive_model_path))
|
||||||
|
tts.autoregressive.post_init_gpt2_config(kv_cache=tts.use_kv_cache)
|
||||||
|
if tts.preloaded_tensors:
|
||||||
|
tts.autoregressive = tts.autoregressive.to(tts.device)
|
||||||
|
|
||||||
|
if not hasattr(tts, 'autoregressive_model_hash'):
|
||||||
|
tts.autoregressive_model_hash = hash_file(autoregressive_model_path)
|
||||||
|
|
||||||
print(f"Loaded model: {tts.autoregressive_model_path}")
|
print(f"Loaded model: {tts.autoregressive_model_path}")
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
return autoregressive_model_path
|
return autoregressive_model_path
|
||||||
|
|
||||||
def update_vocoder_model(vocoder_model):
|
|
||||||
args.vocoder_model = vocoder_model
|
|
||||||
save_args_settings()
|
|
||||||
print(f'Stored vocoder model to settings: {vocoder_model}')
|
|
||||||
|
|
||||||
global tts
|
|
||||||
if not tts:
|
|
||||||
if tts_loading:
|
|
||||||
raise Exception("TTS is still initializing...")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"Loading model: {vocoder_model}")
|
|
||||||
tts.load_vocoder_model(vocoder_model)
|
|
||||||
print(f"Loaded model: {tts.vocoder_model}")
|
|
||||||
|
|
||||||
do_gc()
|
|
||||||
|
|
||||||
return vocoder_model
|
|
||||||
|
|
||||||
def load_voicefixer(restart=False):
|
def load_voicefixer(restart=False):
|
||||||
global voicefixer
|
global voicefixer
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user