Compare commits
5 Commits
Author | SHA1 | Date
---|---|---
 | 1e2436aac9 |
 | f657f30e2b |
 | 4f123910fb |
 | 9ca5192309 |
 | 079cd32074 |
@@ -1,4 +1,4 @@
-git+https://github.com/openai/whisper.git
+git+https://github.com/m-bain/whisperx.git
 more-itertools
 ffmpeg-python
 gradio
src/utils.py (322 changed lines)
@@ -1,4 +1,5 @@
 import os
 
 if 'XDG_CACHE_HOME' not in os.environ:
 os.environ['XDG_CACHE_HOME'] = os.path.realpath(os.path.join(os.getcwd(), './models/'))
+

@@ -27,6 +28,7 @@ import music_tag
 import gradio as gr
 import gradio.utils
 import pandas as pd
+import whisperx
 
 from datetime import datetime
 from datetime import timedelta

@@ -36,9 +38,9 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name
 
-MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
-WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
-WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
+MODELS[
+'dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
+WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]
 EPOCH_SCHEDULE = [9, 18, 25, 33]
 
 args = None
@@ -49,6 +51,7 @@ voicefixer = None
 whisper_model = None
 training_state = None
 
+
 def generate(
 text,
 delimiter,

@@ -110,13 +113,16 @@ def generate(
 if voice_samples and len(voice_samples) > 0:
 sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
 
-conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
+conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean,
+progress=progress, slices=voice_latents_chunks,
+force_cpu=args.force_cpu_for_conditioning_latents)
 if len(conditioning_latents) == 4:
 conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
 
 if voice != "microphone":
 if hasattr(tts, 'autoregressive_model_hash'):
-torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+torch.save(conditioning_latents,
+f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
 else:
 torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
 voice_samples = None

@@ -132,10 +138,10 @@ def generate(
 seed = None
 
 if conditioning_latents is not None and len(conditioning_latents) == 2 and cvvp_weight > 0:
-print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
+print(
+"Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
 cvvp_weight = 0
 
-
 settings = {
 'temperature': float(temperature),
 

@@ -199,7 +205,8 @@ def generate(
 beta=8.555504641634386,
 )
 
-volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
+volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume,
+gain_type="amplitude") if args.output_volume != 1 else None
 
 idx = 0
 idx_cache = {}
|
@ -380,7 +387,6 @@ def generate(
|
||||||
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(info, indent='\t'))
|
f.write(json.dumps(info, indent='\t'))
|
||||||
|
|
||||||
|
|
||||||
if voice and voice != "random" and conditioning_latents is not None:
|
if voice and voice != "random" and conditioning_latents is not None:
|
||||||
latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
|
latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
|
||||||
|
|
||||||
|
@ -428,10 +434,12 @@ def generate(
|
||||||
stats,
|
stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cancel_generate():
|
def cancel_generate():
|
||||||
import tortoise.api
|
import tortoise.api
|
||||||
tortoise.api.STOP_SIGNAL = True
|
tortoise.api.STOP_SIGNAL = True
|
||||||
|
|
||||||
|
|
||||||
def hash_file(path, algo="md5", buffer_size=0):
|
def hash_file(path, algo="md5", buffer_size=0):
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
@ -458,6 +466,7 @@ def hash_file(path, algo="md5", buffer_size=0):
|
||||||
|
|
||||||
return "{0}".format(hash.hexdigest())
|
return "{0}".format(hash.hexdigest())
|
||||||
|
|
||||||
|
|
||||||
def update_baseline_for_latents_chunks(voice):
|
def update_baseline_for_latents_chunks(voice):
|
||||||
path = f'{get_voice_dir()}/{voice}/'
|
path = f'{get_voice_dir()}/{voice}/'
|
||||||
if not os.path.isdir(path):
|
if not os.path.isdir(path):
|
||||||
|
@ -481,6 +490,7 @@ def update_baseline_for_latents_chunks( voice ):
|
||||||
return int(total_duration / total) if total > 0 else 1
|
return int(total_duration / total) if total > 0 else 1
|
||||||
return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1
|
return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1
|
||||||
|
|
||||||
|
|
||||||
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||||
global tts
|
global tts
|
||||||
global args
|
global args
|
||||||
|
@@ -498,18 +508,22 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 if voice_samples is None:
 return
 
-conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
+conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean,
+progress=progress, slices=voice_latents_chunks,
+force_cpu=args.force_cpu_for_conditioning_latents)
 
 if len(conditioning_latents) == 4:
 conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
 
 if hasattr(tts, 'autoregressive_model_hash'):
-torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+torch.save(conditioning_latents,
+f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
 else:
 torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
 
 return voice
 
+
 # superfluous, but it cleans up some things
 class TrainingState():
 def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
@@ -580,7 +594,8 @@ class TrainingState():
 self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', str(int(gpus)), config_path]
 
 print("Spawning process: ", " ".join(self.cmd))
-self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
+self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+universal_newlines=True)
 
 def load_losses(self, update=False):
 if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):

@@ -599,7 +614,8 @@ class TrainingState():
 self.statistics = []
 
 if use_tensorboard:
-logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
+logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if
+d[:6] == "events"])
 if update:
 logs = [logs[-1]]
 
@@ -893,6 +909,7 @@ class TrainingState():
 message,
 )
 
+
 def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
 global training_state
 if training_state and training_state.process:

@@ -911,7 +928,8 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
 if training_state.killed:
 return
 
-result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
+result, percent, message = training_state.parse(line=line, verbose=verbose,
+keep_x_past_datasets=keep_x_past_datasets, progress=progress)
 print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
 if result:
 yield result

@@ -924,6 +942,7 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
 return_code = training_state.process.wait()
 training_state = None
 
+
 def update_training_dataplot(config_path=None):
 global training_state
 update = None
@@ -932,22 +951,29 @@ def update_training_dataplot(config_path=None):
 if config_path:
 training_state = TrainingState(config_path=config_path, start=False)
 if training_state.statistics:
-update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
+update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics),
+x_lim=[0, training_state.its], x="step", y="value",
+title="Training Metrics", color="type", tooltip=['step', 'value', 'type'],
+width=600, height=350, )
 del training_state
 training_state = None
 elif training_state.statistics:
 training_state.load_losses()
-update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
+update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0, training_state.its],
+x="step", y="value", title="Training Metrics", color="type",
+tooltip=['step', 'value', 'type'], width=600, height=350, )
 
 return update
 
+
 def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
 global training_state
 if not training_state or not training_state.process:
 return "Training not in progress"
 
 for line in iter(training_state.process.stdout.readline, ""):
-result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
+result, percent, message = training_state.parse(line=line, verbose=verbose,
+keep_x_past_datasets=keep_x_past_datasets, progress=progress)
 print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
 if result:
 yield result
@@ -955,6 +981,7 @@ def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
 if progress is not None and message:
 progress(percent, message)
 
+
 def stop_training():
 global training_state
 if training_state is None:

@@ -965,7 +992,8 @@ def stop_training():
 children = []
 # wrapped in a try/catch in case for some reason this fails outside of Linux
 try:
-children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
+children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if
+'./src/train.py' in p.info['cmdline']]
 except Exception as e:
 pass
 
@@ -981,10 +1009,12 @@ def stop_training():
 print("Killed training process.")
 return f"Training cancelled: {return_code}"
 
+
 def get_halfp_model_path():
 autoregressive_model_path = get_model_path('autoregressive.pth')
 return autoregressive_model_path.replace(".pth", "_half.pth")
 
+
 def convert_to_halfp():
 autoregressive_model_path = get_model_path('autoregressive.pth')
 print(f'Converting model to half precision: {autoregressive_model_path}')
@@ -996,66 +1026,59 @@ def convert_to_halfp():
 torch.save(model, outfile)
 print(f'Converted model to half precision: {outfile}')
 
-def whisper_transcribe( file, language=None ):
-# shouldn't happen, but it's for safety
-if not whisper_model:
-load_whisper_model(language=language)
-
-if not args.whisper_cpp:
-if not language:
-language = None
-
-return whisper_model.transcribe(file, language=language)
-
-res = whisper_model.transcribe(file)
-segments = whisper_model.extract_text_and_timestamps( res )
-
-result = {
-'segments': []
-}
-for segment in segments:
-reparsed = {
-'start': segment[0] / 100.0,
-'end': segment[1] / 100.0,
-'text': segment[2],
-}
-result['segments'].append(reparsed)
-
-return result
-
 
 def prepare_dataset(files, outdir, language=None, progress=None):
 unload_tts()
 
 global whisper_model
 if whisper_model is None:
-load_whisper_model(language=language)
+load_whisper_model()
 
 os.makedirs(outdir, exist_ok=True)
 
+idx = 0
 results = {}
 transcription = []
 
+idx = 0
+results = {}
+transcription = []
+
+if (torch.cuda.is_available()):
+device = "cuda"
+else:
+device = "cpu"
+
 for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-basename = os.path.basename(file)
-result = whisper_transcribe(file, language=language)
-results[basename] = result
+print(f"Transcribing file: {file}")
+result = whisper_model.transcribe(file)
+
+print(result["segments"]) # before alignment
+
+# load alignment model and metadata
+model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+
+# align whisper output
+result_aligned = whisperx.align(result["segments"], model_a, metadata, file, device)
+
+print(result_aligned["segments"]) # after alignment
+print(result_aligned["word_segments"]) # after alignment
+
+results[os.path.basename(file)] = result
+
 print(f"Transcribed file: {file}, {len(result['segments'])} found.")
 
 waveform, sampling_rate = torchaudio.load(file)
 num_channels, num_frames = waveform.shape
 
-idx = 0
-for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+for segment in result[
+'segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
 start = int(segment['start'] * sampling_rate)
 end = int(segment['end'] * sampling_rate)
 
 sliced_waveform = waveform[:, start:end]
-sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
+sliced_name = f"{pad(idx, 4)}.wav"
 
-if not torch.any(sliced_waveform < 0):
-print(f"Error with {sliced_name}, skipping...")
-continue
-
 torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
 
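The new prepare_dataset() path above follows WhisperX's two-pass pattern: a plain Whisper transcription for coarse segment timestamps, a language-specific alignment model to tighten them, then slicing the waveform by start/end seconds multiplied by the sample rate. The sketch below restates that flow using only the calls visible in this diff; the model size, file path, and output directory are illustrative placeholders, not values taken from the PR.

```python
import torch
import torchaudio
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model = whisperx.load_model("base", device)        # plain Whisper checkpoint, wrapped by WhisperX

result = whisper_model.transcribe("speaker.wav")           # coarse, segment-level timestamps (seconds)

# second pass: load the alignment model for the detected language and refine the timestamps
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result_aligned = whisperx.align(result["segments"], model_a, metadata, "speaker.wav", device)

# each segment's start/end in seconds maps to sample indices for slicing out training clips
waveform, sampling_rate = torchaudio.load("speaker.wav")
for idx, segment in enumerate(result["segments"]):
    start = int(segment["start"] * sampling_rate)
    end = int(segment["end"] * sampling_rate)
    torchaudio.save(f"./training/{idx:04d}.wav", waveform[:, start:end], sampling_rate)
```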
@@ -1076,14 +1099,19 @@ prepare_dataset( files, outdir, language=None, progress=None ):
 
 return f"Processed dataset to: {outdir}\n{joined}"
 
+
 def calc_iterations(epochs, lines, batch_size):
 iterations = int(epochs * lines / float(batch_size))
 return iterations
 
+
 def schedule_learning_rate(iterations, schedule=EPOCH_SCHEDULE):
 return [int(iterations * d) for d in schedule]
 
+
-def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
+def optimize_training_settings(epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size,
+gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers,
+source_model, voice):
 name = f"{voice}-finetune"
 dataset_name = f"{voice}-train"
 dataset_path = f"./training/{voice}/train.txt"
@@ -1102,7 +1130,8 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
 if batch_size % lines != 0:
 nearest_slice = int(lines / batch_size) + 1
 batch_size = int(lines / nearest_slice)
-messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
+messages.append(
+f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
 
 if gradient_accumulation_size == 0:
 gradient_accumulation_size = 1

@@ -1112,13 +1141,15 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
 if gradient_accumulation_size == 0:
 gradient_accumulation_size = 1
 
-messages.append(f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
+messages.append(
+f"Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: {gradient_accumulation_size}")
 elif batch_size % gradient_accumulation_size != 0:
 gradient_accumulation_size = int(batch_size / gradient_accumulation_size)
 if gradient_accumulation_size == 0:
 gradient_accumulation_size = 1
 
-messages.append(f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
+messages.append(
+f"Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: {gradient_accumulation_size}")
 
 iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
 
@@ -1140,13 +1171,15 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
 if half_p:
 if bnb:
 half_p = False
-messages.append("Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
+messages.append(
+"Half Precision requested, but BitsAndBytes is also requested. Due to redundancies, disabling half precision...")
 else:
 messages.append("Half Precision requested. Please note this is ! EXPERIMENTAL !")
 if not os.path.exists(get_halfp_model_path()):
 convert_to_halfp()
 
-messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
+messages.append(
+f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
 
 return (
 learning_rate,

@@ -1160,7 +1193,11 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
 messages
 )
 
-def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
+
+def save_training_settings(iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None,
+batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None,
+dataset_name=None, dataset_path=None, validation_name=None, validation_path=None,
+output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None):
 if not source_model:
 source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
 
@@ -1201,7 +1238,6 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
 if not output_name:
 output_name = f'{settings["name"]}.yaml'
 
-
 with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
 yaml = f.read()
 
@@ -1217,6 +1253,7 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
 
 return f"Training settings saved to: {outfile}"
 
+
 def import_voices(files, saveAs=None, progress=None):
 global args
 
@@ -1282,13 +1319,16 @@ def import_voices(files, saveAs=None, progress=None):
 
 print(f"Imported voice to {path}")
 
+
 def get_voice_list(dir=get_voice_dir(), append_defaults=False):
 os.makedirs(dir, exist_ok=True)
-res = sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
+res = sorted([d for d in os.listdir(dir) if
+os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0])
 if append_defaults:
 res = res + ["random", "microphone"]
 return res
 
+
 def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
 os.makedirs(dir, exist_ok=True)
 base = [get_model_path('autoregressive.pth')]
@ -1316,11 +1356,16 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_list(dir="./training/"):
|
def get_dataset_list(dir="./training/"):
|
||||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(
|
||||||
|
os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d))])
|
||||||
|
|
||||||
|
|
||||||
def get_training_list(dir="./training/"):
|
def get_training_list(dir="./training/"):
|
||||||
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(
|
||||||
|
os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d))])
|
||||||
|
|
||||||
|
|
||||||
def do_gc():
|
def do_gc():
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
@@ -1329,9 +1374,11 @@ def do_gc():
 except Exception as e:
 pass
 
+
 def pad(num, zeroes):
 return str(num).zfill(zeroes + 1)
 
+
 def curl(url):
 try:
 req = urllib.request.Request(url, headers={'User-Agent': 'Python'})
@@ -1345,6 +1392,7 @@ def curl(url):
 print(e)
 return None
 
+
 def check_for_updates():
 if not os.path.isfile('./.git/FETCH_HEAD'):
 print("Cannot check for updates: not from a git repo")
@@ -1379,13 +1427,16 @@ def check_for_updates():
 
 return False
 
+
 def enumerate_progress(iterable, desc=None, progress=None, verbose=None):
 if verbose and desc is not None:
 print(desc)
 
 if progress is None:
 return tqdm(iterable, disable=not verbose)
-return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)
+return progress.tqdm(iterable, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc,
+track_tqdm=True)
 
+
 def notify_progress(message, progress=None, verbose=True):
 if verbose:
@@ -1396,10 +1447,12 @@ def notify_progress(message, progress=None, verbose=True):
 
 progress(0, desc=message)
 
+
 def get_args():
 global args
 return args
 
+
 def setup_args():
 global args
 
@@ -1412,7 +1465,8 @@ def setup_args():
 'sample-batch-size': None,
 'embed-output-metadata': True,
 'latents-lean-and-mean': True,
-'voice-fixer': False, # getting tired of long initialization times in a Colab for downloading a large dataset for it
+'voice-fixer': False,
+# getting tired of long initialization times in a Colab for downloading a large dataset for it
 'voice-fixer-use-cuda': True,
 'force-cpu-for-conditioning-latents': False,
 'defer-tts-load': False,
@@ -1443,32 +1497,61 @@ def setup_args():
 pass
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--share", action='store_true', default=default_arguments['share'], help="Lets Gradio return a public URL to use anywhere")
+parser.add_argument("--share", action='store_true', default=default_arguments['share'],
+help="Lets Gradio return a public URL to use anywhere")
 parser.add_argument("--listen", default=default_arguments['listen'], help="Path for Gradio to listen on")
-parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'], help="Checks for update on startup")
-parser.add_argument("--models-from-local-only", action='store_true', default=default_arguments['models-from-local-only'], help="Only loads models from disk, does not check for updates for models")
-parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'], help="Disables some optimizations that increases VRAM usage")
-parser.add_argument("--no-embed-output-metadata", action='store_false', default=not default_arguments['embed-output-metadata'], help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
-parser.add_argument("--latents-lean-and-mean", action='store_true', default=default_arguments['latents-lean-and-mean'], help="Exports the bare essentials for latents.")
-parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'], help="Uses python module 'voicefixer' to improve audio quality, if available.")
-parser.add_argument("--voice-fixer-use-cuda", action='store_true', default=default_arguments['voice-fixer-use-cuda'], help="Hints to voicefixer to use CUDA, if available.")
-parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
-parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
-parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
-parser.add_argument("--use-bigvgan-vocoder", default=default_arguments['use-bigvgan-vocoder'], action='store_true', help="Uses BigVGAN in place of the default vocoder")
-parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
-parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
-parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
-parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float, default=default_arguments['autocalculate-voice-chunk-duration-size'], help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
-parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'], help="Sample rate to resample the output to (from 24KHz)")
-parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'], help="Adjusts volume of output")
+parser.add_argument("--check-for-updates", action='store_true', default=default_arguments['check-for-updates'],
+help="Checks for update on startup")
+parser.add_argument("--models-from-local-only", action='store_true',
+default=default_arguments['models-from-local-only'],
+help="Only loads models from disk, does not check for updates for models")
+parser.add_argument("--low-vram", action='store_true', default=default_arguments['low-vram'],
+help="Disables some optimizations that increases VRAM usage")
+parser.add_argument("--no-embed-output-metadata", action='store_false',
+default=not default_arguments['embed-output-metadata'],
+help="Disables embedding output metadata into resulting WAV files for easily fetching its settings used with the web UI (data is stored in the lyrics metadata tag)")
+parser.add_argument("--latents-lean-and-mean", action='store_true',
+default=default_arguments['latents-lean-and-mean'],
+help="Exports the bare essentials for latents.")
+parser.add_argument("--voice-fixer", action='store_true', default=default_arguments['voice-fixer'],
+help="Uses python module 'voicefixer' to improve audio quality, if available.")
+parser.add_argument("--voice-fixer-use-cuda", action='store_true',
+default=default_arguments['voice-fixer-use-cuda'],
+help="Hints to voicefixer to use CUDA, if available.")
+parser.add_argument("--force-cpu-for-conditioning-latents",
+default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true',
+help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
+parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true',
+help="Defers loading TTS model")
+parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'],
+action='store_true', help="Deletes non-final output files on completing a generation")
+parser.add_argument("--use-bigvgan-vocoder", default=default_arguments['use-bigvgan-vocoder'], action='store_true',
+help="Uses BigVGAN in place of the default vocoder")
+parser.add_argument("--device-override", default=default_arguments['device-override'],
+help="A device string to override pass through Torch")
+parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int,
+help="Sets how many batches to use during the autoregressive samples pass")
+parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'],
+help="How many Gradio events to process at once")
+parser.add_argument("--autocalculate-voice-chunk-duration-size", type=float,
+default=default_arguments['autocalculate-voice-chunk-duration-size'],
+help="Number of seconds to suggest voice chunk size for (for example, 100 seconds of audio at 10 seconds per chunk will suggest 10 chunks)")
+parser.add_argument("--output-sample-rate", type=int, default=default_arguments['output-sample-rate'],
+help="Sample rate to resample the output to (from 24KHz)")
+parser.add_argument("--output-volume", type=float, default=default_arguments['output-volume'],
+help="Adjusts volume of output")
 
-parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'], help="Specifies which autoregressive model to use for sampling.")
-parser.add_argument("--whisper-model", default=default_arguments['whisper-model'], help="Specifies which whisper model to use for transcription.")
-parser.add_argument("--whisper-cpp", default=default_arguments['whisper-cpp'], action='store_true', help="Leverages lightmare/whispercpp for transcription")
+parser.add_argument("--autoregressive-model", default=default_arguments['autoregressive-model'],
+help="Specifies which autoregressive model to use for sampling.")
+parser.add_argument("--whisper-model", default=default_arguments['whisper-model'],
+help="Specifies which whisper model to use for transcription.")
+parser.add_argument("--whisper-cpp", default=default_arguments['whisper-cpp'], action='store_true',
+help="Leverages lightmare/whispercpp for transcription")
 
-parser.add_argument("--training-default-halfp", action='store_true', default=default_arguments['training-default-halfp'], help="Training default: halfp")
-parser.add_argument("--training-default-bnb", action='store_true', default=default_arguments['training-default-bnb'], help="Training default: bnb")
+parser.add_argument("--training-default-halfp", action='store_true',
+default=default_arguments['training-default-halfp'], help="Training default: halfp")
+parser.add_argument("--training-default-bnb", action='store_true',
+default=default_arguments['training-default-bnb'], help="Training default: bnb")
 
 parser.add_argument("--os", default="unix", help="Specifies which OS, easily")
 args = parser.parse_args()
|
||||||
if not args.device_override:
|
if not args.device_override:
|
||||||
set_device_name(args.device_override)
|
set_device_name(args.device_override)
|
||||||
|
|
||||||
|
|
||||||
args.listen_host = None
|
args.listen_host = None
|
||||||
args.listen_port = None
|
args.listen_port = None
|
||||||
args.listen_path = None
|
args.listen_path = None
|
||||||
|
@ -1499,7 +1581,12 @@ def setup_args():
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size, concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
|
|
||||||
|
def update_args(listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata,
|
||||||
|
latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents,
|
||||||
|
defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size,
|
||||||
|
concurrency_count, autocalculate_voice_chunk_duration_size, output_volume, autoregressive_model,
|
||||||
|
whisper_model, whisper_cpp, training_default_halfp, training_default_bnb):
|
||||||
global args
|
global args
|
||||||
|
|
||||||
args.listen = listen
|
args.listen = listen
|
||||||
|
@ -1531,6 +1618,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
|
||||||
|
|
||||||
save_args_settings()
|
save_args_settings()
|
||||||
|
|
||||||
|
|
||||||
def save_args_settings():
|
def save_args_settings():
|
||||||
global args
|
global args
|
||||||
settings = {
|
settings = {
|
||||||
|
@ -1567,7 +1655,6 @@ def save_args_settings():
|
||||||
f.write(json.dumps(settings, indent='\t'))
|
f.write(json.dumps(settings, indent='\t'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def import_generate_settings(file="./config/generate.json"):
|
def import_generate_settings(file="./config/generate.json"):
|
||||||
settings, _ = read_generate_settings(file, read_latents=False)
|
settings, _ = read_generate_settings(file, read_latents=False)
|
||||||
|
|
||||||
|
@ -1604,6 +1691,7 @@ def reset_generation_settings():
|
||||||
f.write(json.dumps({}, indent='\t'))
|
f.write(json.dumps({}, indent='\t'))
|
||||||
return import_generate_settings()
|
return import_generate_settings()
|
||||||
|
|
||||||
|
|
||||||
def read_generate_settings(file, read_latents=True):
|
def read_generate_settings(file, read_latents=True):
|
||||||
j = None
|
j = None
|
||||||
latents = None
|
latents = None
|
||||||
|
@ -1632,17 +1720,15 @@ def read_generate_settings(file, read_latents=True):
|
||||||
latents = base64.b64decode(j['latents'])
|
latents = base64.b64decode(j['latents'])
|
||||||
del j['latents']
|
del j['latents']
|
||||||
|
|
||||||
|
|
||||||
if "time" in j:
|
if "time" in j:
|
||||||
j["time"] = "{:.3f}".format(j["time"])
|
j["time"] = "{:.3f}".format(j["time"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
j,
|
j,
|
||||||
latents,
|
latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_tts(restart=False, model=None):
|
def load_tts(restart=False, model=None):
|
||||||
global args
|
global args
|
||||||
global tts
|
global tts
|
||||||
|
@ -1650,7 +1736,6 @@ def load_tts( restart=False, model=None ):
|
||||||
if restart:
|
if restart:
|
||||||
unload_tts()
|
unload_tts()
|
||||||
|
|
||||||
|
|
||||||
if model:
|
if model:
|
||||||
args.autoregressive_model = model
|
args.autoregressive_model = model
|
||||||
|
|
||||||
|
@ -1672,8 +1757,10 @@ def load_tts( restart=False, model=None ):
|
||||||
print("Loaded TorToiSe, ready for generation.")
|
print("Loaded TorToiSe, ready for generation.")
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
|
|
||||||
setup_tortoise = load_tts
|
setup_tortoise = load_tts
|
||||||
|
|
||||||
|
|
||||||
def unload_tts():
|
def unload_tts():
|
||||||
global tts
|
global tts
|
||||||
|
|
||||||
|
@ -1683,9 +1770,11 @@ def unload_tts():
|
||||||
print("Unloaded TTS")
|
print("Unloaded TTS")
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
|
|
||||||
def reload_tts(model=None):
|
def reload_tts(model=None):
|
||||||
load_tts(restart=True, model=model)
|
load_tts(restart=True, model=model)
|
||||||
|
|
||||||
|
|
||||||
def update_autoregressive_model(autoregressive_model_path):
|
def update_autoregressive_model(autoregressive_model_path):
|
||||||
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
||||||
if match:
|
if match:
|
||||||
|
@ -1714,7 +1803,8 @@ def update_autoregressive_model(autoregressive_model_path):
|
||||||
else:
|
else:
|
||||||
from tortoise.models.autoregressive import UnifiedVoice
|
from tortoise.models.autoregressive import UnifiedVoice
|
||||||
|
|
||||||
tts.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', tts.models_dir)
|
tts.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(
|
||||||
|
autoregressive_model_path) else get_model_path('autoregressive.pth', tts.models_dir)
|
||||||
|
|
||||||
del tts.autoregressive
|
del tts.autoregressive
|
||||||
tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
||||||
|
@ -1735,6 +1825,7 @@ def update_autoregressive_model(autoregressive_model_path):
|
||||||
|
|
||||||
return autoregressive_model_path
|
return autoregressive_model_path
|
||||||
|
|
||||||
|
|
||||||
def load_voicefixer(restart=False):
|
def load_voicefixer(restart=False):
|
||||||
global voicefixer
|
global voicefixer
|
||||||
|
|
||||||
|
@ -1749,6 +1840,7 @@ def load_voicefixer(restart=False):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error occurred while tring to initialize voicefixer: {e}")
|
print(f"Error occurred while tring to initialize voicefixer: {e}")
|
||||||
|
|
||||||
|
|
||||||
def unload_voicefixer():
|
def unload_voicefixer():
|
||||||
global voicefixer
|
global voicefixer
|
||||||
|
|
||||||
|
@ -1759,7 +1851,8 @@ def unload_voicefixer():
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
def load_whisper_model(language=None, model_name=None, progress=None):
|
|
||||||
|
def load_whisper_model(model_name=None, progress=None):
|
||||||
global whisper_model
|
global whisper_model
|
||||||
|
|
||||||
if not model_name:
|
if not model_name:
|
||||||
|
@ -1768,24 +1861,17 @@ def load_whisper_model(language=None, model_name=None, progress=None):
|
||||||
args.whisper_model = model_name
|
args.whisper_model = model_name
|
||||||
save_args_settings()
|
save_args_settings()
|
||||||
|
|
||||||
if language and f'{model_name}.{language}' in WHISPER_SPECIALIZED_MODELS:
|
if (torch.cuda.is_available()):
|
||||||
model_name = f'{model_name}.{language}'
|
device = "cuda"
|
||||||
print(f"Loading specialized model for language: {language}")
|
|
||||||
|
|
||||||
notify_progress(f"Loading Whisper model: {model_name}", progress)
|
|
||||||
|
|
||||||
if args.whisper_cpp:
|
|
||||||
from whispercpp import Whisper
|
|
||||||
if not language:
|
|
||||||
language = 'auto'
|
|
||||||
|
|
||||||
b_lang = language.encode('ascii')
|
|
||||||
whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
|
|
||||||
else:
|
else:
|
||||||
import whisper
|
device = "cpu"
|
||||||
whisper_model = whisper.load_model(model_name)
|
|
||||||
|
notify_progress(f"Loading WhisperX model: {model_name} using {device}", progress)
|
||||||
|
|
||||||
|
whisper_model = whisperx.load_model(model_name, device)
|
||||||
|
|
||||||
|
print("Loaded WhisperX model")
|
||||||
|
|
||||||
print("Loaded Whisper model")
|
|
||||||
|
|
||||||
def unload_whisper():
|
def unload_whisper():
|
||||||
global whisper_model
|
global whisper_model
|
||||||
|
@ -1793,6 +1879,6 @@ def unload_whisper():
|
||||||
if whisper_model:
|
if whisper_model:
|
||||||
del whisper_model
|
del whisper_model
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
print("Unloaded Whisper")
|
print("Unloaded WhisperX")
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
|
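With the loader reduced to the WhisperX path above (the language-specialized .en checkpoints and the whisper.cpp branch are gone), the module-level helpers are driven the same way as before. A hedged usage sketch follows; the voice paths and model name are placeholders, and it assumes the script is run from src/ with the usual config files present.

```python
from utils import setup_args, load_whisper_model, prepare_dataset, unload_whisper

setup_args()                              # populates the global args that the helpers read
load_whisper_model(model_name="base")     # picks cuda/cpu internally and loads via whisperx.load_model
prepare_dataset(["../voices/mia/clip1.wav"], "../training/mia/")
unload_whisper()                          # drops the model reference and runs do_gc()
```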
@@ -590,9 +590,8 @@ def setup_gradio():
 autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
 
 whisper_model_dropdown = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model)
-use_whisper_cpp = gr.Checkbox(label="Use Whisper.cpp", value=args.whisper_cpp)
 
-exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_model_dropdown, use_whisper_cpp, training_halfp, training_bnb ]
+exec_inputs = exec_inputs + [ autoregressive_model_dropdown, whisper_model_dropdown, training_halfp, training_bnb ]
 
 with gr.Row():
 autoregressive_models_update_button = gr.Button(value="Refresh Model List")