made initialization faster if there's a lot of voice files (because glob fucking sucks), commiting changes buried on my training rig
This commit is contained in:
parent
91a0c495ff
commit
72a38ff2fc
|
@ -1 +1 @@
|
||||||
Subproject commit 5ff00bf3bfa97e2c8e9f166b920273f83ac9d8f0
|
Subproject commit cbd3c95c42ac1da9772f61b9895954ee693075c9
|
|
@ -7,4 +7,5 @@ music-tag
|
||||||
voicefixer
|
voicefixer
|
||||||
psutil
|
psutil
|
||||||
phonemizer
|
phonemizer
|
||||||
pydantic==1.10.11
|
pydantic==1.10.11
|
||||||
|
websockets
|
314
src/utils.py
314
src/utils.py
|
@ -45,7 +45,7 @@ from tortoise.utils.device import get_device_name, set_device_name, get_device_c
|
||||||
|
|
||||||
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||||
|
|
||||||
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
||||||
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||||
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
|
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
|
||||||
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
|
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
|
||||||
|
@ -61,12 +61,15 @@ RESAMPLERS = {}
|
||||||
|
|
||||||
MIN_TRAINING_DURATION = 0.6
|
MIN_TRAINING_DURATION = 0.6
|
||||||
MAX_TRAINING_DURATION = 11.6097505669
|
MAX_TRAINING_DURATION = 11.6097505669
|
||||||
|
MAX_TRAINING_CHAR_LENGTH = 200
|
||||||
|
|
||||||
VALLE_ENABLED = False
|
VALLE_ENABLED = False
|
||||||
BARK_ENABLED = False
|
BARK_ENABLED = False
|
||||||
|
|
||||||
VERBOSE_DEBUG = True
|
VERBOSE_DEBUG = True
|
||||||
|
|
||||||
|
import traceback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from whisper.normalizers.english import EnglishTextNormalizer
|
from whisper.normalizers.english import EnglishTextNormalizer
|
||||||
from whisper.normalizers.basic import BasicTextNormalizer
|
from whisper.normalizers.basic import BasicTextNormalizer
|
||||||
|
@ -75,7 +78,7 @@ try:
|
||||||
print("Whisper detected")
|
print("Whisper detected")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if VERBOSE_DEBUG:
|
if VERBOSE_DEBUG:
|
||||||
print("Error:", e)
|
print(traceback.format_exc())
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -90,12 +93,14 @@ try:
|
||||||
VALLE_ENABLED = True
|
VALLE_ENABLED = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if VERBOSE_DEBUG:
|
if VERBOSE_DEBUG:
|
||||||
print("Error:", e)
|
print(traceback.format_exc())
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if VALLE_ENABLED:
|
if VALLE_ENABLED:
|
||||||
TTSES.append('vall-e')
|
TTSES.append('vall-e')
|
||||||
|
|
||||||
|
# torchaudio.set_audio_backend('soundfile')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import bark
|
import bark
|
||||||
from bark import text_to_semantic
|
from bark import text_to_semantic
|
||||||
|
@ -109,35 +114,10 @@ try:
|
||||||
BARK_ENABLED = True
|
BARK_ENABLED = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if VERBOSE_DEBUG:
|
if VERBOSE_DEBUG:
|
||||||
print("Error:", e)
|
print(traceback.format_exc())
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if BARK_ENABLED:
|
if BARK_ENABLED:
|
||||||
try:
|
|
||||||
from vocos import Vocos
|
|
||||||
VOCOS_ENABLED = True
|
|
||||||
print("Vocos detected")
|
|
||||||
except Exception as e:
|
|
||||||
if VERBOSE_DEBUG:
|
|
||||||
print("Error:", e)
|
|
||||||
VOCOS_ENABLED = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from hubert.hubert_manager import HuBERTManager
|
|
||||||
from hubert.pre_kmeans_hubert import CustomHubert
|
|
||||||
from hubert.customtokenizer import CustomTokenizer
|
|
||||||
|
|
||||||
hubert_manager = HuBERTManager()
|
|
||||||
hubert_manager.make_sure_hubert_installed()
|
|
||||||
hubert_manager.make_sure_tokenizer_installed()
|
|
||||||
|
|
||||||
HUBERT_ENABLED = True
|
|
||||||
print("HuBERT detected")
|
|
||||||
except Exception as e:
|
|
||||||
if VERBOSE_DEBUG:
|
|
||||||
print("Error:", e)
|
|
||||||
HUBERT_ENABLED = False
|
|
||||||
|
|
||||||
TTSES.append('bark')
|
TTSES.append('bark')
|
||||||
|
|
||||||
def semantic_to_audio_tokens(
|
def semantic_to_audio_tokens(
|
||||||
|
@ -181,7 +161,32 @@ if BARK_ENABLED:
|
||||||
|
|
||||||
self.device = get_device_name()
|
self.device = get_device_name()
|
||||||
|
|
||||||
if VOCOS_ENABLED:
|
try:
|
||||||
|
from vocos import Vocos
|
||||||
|
self.vocos_enabled = True
|
||||||
|
print("Vocos detected")
|
||||||
|
except Exception as e:
|
||||||
|
if VERBOSE_DEBUG:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
self.vocos_enabled = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hubert.hubert_manager import HuBERTManager
|
||||||
|
from hubert.pre_kmeans_hubert import CustomHubert
|
||||||
|
from hubert.customtokenizer import CustomTokenizer
|
||||||
|
|
||||||
|
hubert_manager = HuBERTManager()
|
||||||
|
hubert_manager.make_sure_hubert_installed()
|
||||||
|
hubert_manager.make_sure_tokenizer_installed()
|
||||||
|
|
||||||
|
self.hubert_enabled = True
|
||||||
|
print("HuBERT detected")
|
||||||
|
except Exception as e:
|
||||||
|
if VERBOSE_DEBUG:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
self.hubert_enabled = False
|
||||||
|
|
||||||
|
if self.vocos_enabled:
|
||||||
self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device)
|
self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device)
|
||||||
|
|
||||||
def create_voice( self, voice ):
|
def create_voice( self, voice ):
|
||||||
|
@ -238,7 +243,7 @@ if BARK_ENABLED:
|
||||||
|
|
||||||
# generate semantic tokens
|
# generate semantic tokens
|
||||||
|
|
||||||
if HUBERT_ENABLED:
|
if self.hubert_enabled:
|
||||||
wav = wav.to(self.device)
|
wav = wav.to(self.device)
|
||||||
|
|
||||||
# Extract discrete codes from EnCodec
|
# Extract discrete codes from EnCodec
|
||||||
|
@ -426,7 +431,7 @@ def generate_bark(**kwargs):
|
||||||
idx_cache = {}
|
idx_cache = {}
|
||||||
for i, file in enumerate(os.listdir(outdir)):
|
for i, file in enumerate(os.listdir(outdir)):
|
||||||
filename = os.path.basename(file)
|
filename = os.path.basename(file)
|
||||||
extension = os.path.splitext(filename)[1]
|
extension = os.path.splitext(filename)[-1][1:]
|
||||||
if extension != ".json" and extension != ".wav":
|
if extension != ".json" and extension != ".wav":
|
||||||
continue
|
continue
|
||||||
match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
|
match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
|
||||||
|
@ -672,18 +677,23 @@ def generate_valle(**kwargs):
|
||||||
|
|
||||||
voice_cache = {}
|
voice_cache = {}
|
||||||
def fetch_voice( voice ):
|
def fetch_voice( voice ):
|
||||||
|
if voice in voice_cache:
|
||||||
|
return voice_cache[voice]
|
||||||
voice_dir = f'./training/{voice}/audio/'
|
voice_dir = f'./training/{voice}/audio/'
|
||||||
if not os.path.isdir(voice_dir):
|
|
||||||
|
if not os.path.isdir(voice_dir) or len(os.listdir(voice_dir)) == 0:
|
||||||
voice_dir = f'./voices/{voice}/'
|
voice_dir = f'./voices/{voice}/'
|
||||||
|
|
||||||
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
|
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
|
||||||
# return files
|
# return files
|
||||||
return random.choice(files)
|
voice_cache[voice] = random.choice(files)
|
||||||
|
return voice_cache[voice]
|
||||||
|
|
||||||
def get_settings( override=None ):
|
def get_settings( override=None ):
|
||||||
settings = {
|
settings = {
|
||||||
'ar_temp': float(parameters['temperature']),
|
'ar_temp': float(parameters['temperature']),
|
||||||
'nar_temp': float(parameters['temperature']),
|
'nar_temp': float(parameters['temperature']),
|
||||||
'max_ar_samples': parameters['num_autoregressive_samples'],
|
'max_ar_steps': parameters['num_autoregressive_samples'],
|
||||||
}
|
}
|
||||||
|
|
||||||
# could be better to just do a ternary on everything above, but i am not a professional
|
# could be better to just do a ternary on everything above, but i am not a professional
|
||||||
|
@ -697,7 +707,7 @@ def generate_valle(**kwargs):
|
||||||
continue
|
continue
|
||||||
settings[k] = override[k]
|
settings[k] = override[k]
|
||||||
|
|
||||||
settings['reference'] = fetch_voice(voice=selected_voice)
|
settings['references'] = [ fetch_voice(voice=selected_voice) for _ in range(3) ]
|
||||||
return settings
|
return settings
|
||||||
|
|
||||||
if not parameters['delimiter']:
|
if not parameters['delimiter']:
|
||||||
|
@ -723,7 +733,7 @@ def generate_valle(**kwargs):
|
||||||
idx_cache = {}
|
idx_cache = {}
|
||||||
for i, file in enumerate(os.listdir(outdir)):
|
for i, file in enumerate(os.listdir(outdir)):
|
||||||
filename = os.path.basename(file)
|
filename = os.path.basename(file)
|
||||||
extension = os.path.splitext(filename)[1]
|
extension = os.path.splitext(filename)[-1][1:]
|
||||||
if extension != ".json" and extension != ".wav":
|
if extension != ".json" and extension != ".wav":
|
||||||
continue
|
continue
|
||||||
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
|
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
|
||||||
|
@ -783,11 +793,14 @@ def generate_valle(**kwargs):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
||||||
|
|
||||||
settings = get_settings( override=override )
|
name = get_name(line=line, candidate=0)
|
||||||
reference = settings['reference']
|
|
||||||
settings.pop("reference")
|
|
||||||
|
|
||||||
gen = tts.inference(cut_text, reference, **settings )
|
settings = get_settings( override=override )
|
||||||
|
references = settings['references']
|
||||||
|
settings.pop("references")
|
||||||
|
settings['out_path'] = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
|
||||||
|
|
||||||
|
gen = tts.inference(cut_text, references, **settings )
|
||||||
|
|
||||||
run_time = time.time()-start_time
|
run_time = time.time()-start_time
|
||||||
print(f"Generating line took {run_time} seconds")
|
print(f"Generating line took {run_time} seconds")
|
||||||
|
@ -805,7 +818,7 @@ def generate_valle(**kwargs):
|
||||||
|
|
||||||
# save here in case some error happens mid-batch
|
# save here in case some error happens mid-batch
|
||||||
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
|
||||||
soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
|
#soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
|
||||||
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
|
||||||
|
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
|
@ -1085,7 +1098,7 @@ def generate_tortoise(**kwargs):
|
||||||
idx_cache = {}
|
idx_cache = {}
|
||||||
for i, file in enumerate(os.listdir(outdir)):
|
for i, file in enumerate(os.listdir(outdir)):
|
||||||
filename = os.path.basename(file)
|
filename = os.path.basename(file)
|
||||||
extension = os.path.splitext(filename)[1]
|
extension = os.path.splitext(filename)[-1][1:]
|
||||||
if extension != ".json" and extension != ".wav":
|
if extension != ".json" and extension != ".wav":
|
||||||
continue
|
continue
|
||||||
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
|
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
|
||||||
|
@ -1605,30 +1618,18 @@ class TrainingState():
|
||||||
if args.tts_backend == "vall-e":
|
if args.tts_backend == "vall-e":
|
||||||
keys['lrs'] = [
|
keys['lrs'] = [
|
||||||
'ar.lr', 'nar.lr',
|
'ar.lr', 'nar.lr',
|
||||||
'ar-half.lr', 'nar-half.lr',
|
|
||||||
'ar-quarter.lr', 'nar-quarter.lr',
|
|
||||||
]
|
]
|
||||||
keys['losses'] = [
|
keys['losses'] = [
|
||||||
'ar.loss', 'nar.loss', 'ar+nar.loss',
|
# 'ar.loss', 'nar.loss', 'ar+nar.loss',
|
||||||
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
|
|
||||||
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
|
|
||||||
|
|
||||||
'ar.loss.nll', 'nar.loss.nll',
|
'ar.loss.nll', 'nar.loss.nll',
|
||||||
'ar-half.loss.nll', 'nar-half.loss.nll',
|
|
||||||
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
keys['accuracies'] = [
|
keys['accuracies'] = [
|
||||||
'ar.loss.acc', 'nar.loss.acc',
|
'ar.loss.acc', 'nar.loss.acc',
|
||||||
'ar-half.loss.acc', 'nar-half.loss.acc',
|
'ar.stats.acc', 'nar.loss.acc',
|
||||||
'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
|
|
||||||
]
|
]
|
||||||
keys['precisions'] = [
|
keys['precisions'] = [ 'ar.loss.precision', 'nar.loss.precision', ]
|
||||||
'ar.loss.precision', 'nar.loss.precision',
|
keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm']
|
||||||
'ar-half.loss.precision', 'nar-half.loss.precision',
|
|
||||||
'ar-quarter.loss.precision', 'nar-quarter.loss.precision',
|
|
||||||
]
|
|
||||||
keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']
|
|
||||||
|
|
||||||
for k in keys['lrs']:
|
for k in keys['lrs']:
|
||||||
if k not in self.info:
|
if k not in self.info:
|
||||||
|
@ -1746,7 +1747,8 @@ class TrainingState():
|
||||||
if args.tts_backend == "tortoise":
|
if args.tts_backend == "tortoise":
|
||||||
logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
|
logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
|
||||||
else:
|
else:
|
||||||
logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ])
|
log_dir = "logs"
|
||||||
|
logs = sorted([f'{self.training_dir}/{log_dir}/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/{log_dir}/') ])
|
||||||
|
|
||||||
if update:
|
if update:
|
||||||
logs = [logs[-1]]
|
logs = [logs[-1]]
|
||||||
|
@ -2219,6 +2221,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
||||||
files = get_voice(voice, load_latents=False)
|
files = get_voice(voice, load_latents=False)
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/whisper.json'
|
infile = f'{indir}/whisper.json'
|
||||||
|
|
||||||
|
quantize_in_memory = args.tts_backend == "vall-e"
|
||||||
|
|
||||||
os.makedirs(f'{indir}/audio/', exist_ok=True)
|
os.makedirs(f'{indir}/audio/', exist_ok=True)
|
||||||
|
|
||||||
|
@ -2245,13 +2249,24 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
||||||
continue
|
continue
|
||||||
|
|
||||||
results[basename] = result
|
results[basename] = result
|
||||||
waveform, sample_rate = torchaudio.load(file)
|
|
||||||
# resample to the input rate, since it'll get resampled for training anyways
|
if not quantize_in_memory:
|
||||||
# this should also "help" increase throughput a bit when filling the dataloaders
|
waveform, sample_rate = torchaudio.load(file)
|
||||||
waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
|
# resample to the input rate, since it'll get resampled for training anyways
|
||||||
if waveform.shape[0] == 2:
|
# this should also "help" increase throughput a bit when filling the dataloaders
|
||||||
waveform = waveform[:1]
|
waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
|
||||||
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
|
if waveform.shape[0] == 2:
|
||||||
|
waveform = waveform[:1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs = {}
|
||||||
|
if basename[-4:] == ".wav":
|
||||||
|
kwargs['encoding'] = "PCM_S"
|
||||||
|
kwargs['bits_per_sample'] = 16
|
||||||
|
|
||||||
|
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
with open(infile, 'w', encoding="utf-8") as f:
|
with open(infile, 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(results, indent='\t'))
|
f.write(json.dumps(results, indent='\t'))
|
||||||
|
@ -2317,6 +2332,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
||||||
segments = 0
|
segments = 0
|
||||||
for filename in results:
|
for filename in results:
|
||||||
path = f'./voices/{voice}/{filename}'
|
path = f'./voices/{voice}/{filename}'
|
||||||
|
extension = os.path.splitext(filename)[-1][1:]
|
||||||
|
out_extension = extension # "wav"
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
path = f'./training/{voice}/{filename}'
|
path = f'./training/{voice}/{filename}'
|
||||||
|
|
||||||
|
@ -2333,7 +2351,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
||||||
duration = num_frames / sample_rate
|
duration = num_frames / sample_rate
|
||||||
|
|
||||||
for segment in result['segments']:
|
for segment in result['segments']:
|
||||||
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
|
file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
|
||||||
|
|
||||||
sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
|
sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
|
||||||
if error:
|
if error:
|
||||||
|
@ -2341,12 +2359,17 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
||||||
print(message)
|
print(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
continue
|
continue
|
||||||
sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
|
# sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
|
||||||
|
|
||||||
if waveform.shape[0] == 2:
|
if waveform.shape[0] == 2:
|
||||||
waveform = waveform[:1]
|
waveform = waveform[:1]
|
||||||
|
|
||||||
torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16)
|
kwargs = {}
|
||||||
|
if file[-4:] == ".wav":
|
||||||
|
kwargs['encoding'] = "PCM_S"
|
||||||
|
kwargs['bits_per_sample'] = 16
|
||||||
|
|
||||||
|
torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, **kwargs)
|
||||||
|
|
||||||
segments +=1
|
segments +=1
|
||||||
|
|
||||||
|
@ -2462,18 +2485,32 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
|
|
||||||
errored = 0
|
errored = 0
|
||||||
messages = []
|
messages = []
|
||||||
normalize = True
|
normalize = False # True
|
||||||
phonemize = should_phonemize()
|
phonemize = should_phonemize()
|
||||||
lines = { 'training': [], 'validation': [] }
|
lines = { 'training': [], 'validation': [] }
|
||||||
segments = {}
|
segments = {}
|
||||||
|
|
||||||
|
quantize_in_memory = args.tts_backend == "vall-e"
|
||||||
|
|
||||||
if args.tts_backend != "tortoise":
|
if args.tts_backend != "tortoise":
|
||||||
text_length = 0
|
text_length = 0
|
||||||
audio_length = 0
|
audio_length = 0
|
||||||
|
|
||||||
|
start_offset = -0.1
|
||||||
|
end_offset = 0.1
|
||||||
|
trim_silence = False
|
||||||
|
|
||||||
|
TARGET_SAMPLE_RATE = 22050
|
||||||
|
if args.tts_backend != "tortoise":
|
||||||
|
TARGET_SAMPLE_RATE = 24000
|
||||||
|
if tts:
|
||||||
|
TARGET_SAMPLE_RATE = tts.input_sample_rate
|
||||||
|
|
||||||
for filename in tqdm(results, desc="Parsing results"):
|
for filename in tqdm(results, desc="Parsing results"):
|
||||||
use_segment = use_segments
|
use_segment = use_segments
|
||||||
|
|
||||||
|
extension = os.path.splitext(filename)[-1][1:]
|
||||||
|
out_extension = extension # "wav"
|
||||||
result = results[filename]
|
result = results[filename]
|
||||||
lang = result['language']
|
lang = result['language']
|
||||||
language = LANGUAGES[lang] if lang in LANGUAGES else lang
|
language = LANGUAGES[lang] if lang in LANGUAGES else lang
|
||||||
|
@ -2481,8 +2518,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
|
|
||||||
# check if unsegmented text exceeds 200 characters
|
# check if unsegmented text exceeds 200 characters
|
||||||
if not use_segment:
|
if not use_segment:
|
||||||
if len(result['text']) > 200:
|
if len(result['text']) > MAX_TRAINING_CHAR_LENGTH:
|
||||||
message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}"
|
message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(result['text'])}), using segments: {filename}"
|
||||||
print(message)
|
print(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
use_segment = True
|
use_segment = True
|
||||||
|
@ -2490,13 +2527,15 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
# check if unsegmented audio exceeds 11.6s
|
# check if unsegmented audio exceeds 11.6s
|
||||||
if not use_segment:
|
if not use_segment:
|
||||||
path = f'{indir}/audio/{filename}'
|
path = f'{indir}/audio/{filename}'
|
||||||
if not os.path.exists(path):
|
if not quantize_in_memory and not os.path.exists(path):
|
||||||
messages.append(f"Missing source audio: {filename}")
|
messages.append(f"Missing source audio: {filename}")
|
||||||
errored += 1
|
errored += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
metadata = torchaudio.info(path)
|
duration = 0
|
||||||
duration = metadata.num_frames / metadata.sample_rate
|
for segment in result['segments']:
|
||||||
|
duration = max(duration, result['segments'][segment]['end'])
|
||||||
|
|
||||||
if duration >= MAX_TRAINING_DURATION:
|
if duration >= MAX_TRAINING_DURATION:
|
||||||
message = f"Audio too large, using segments: {filename}"
|
message = f"Audio too large, using segments: {filename}"
|
||||||
print(message)
|
print(message)
|
||||||
|
@ -2511,19 +2550,36 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
|
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
|
path = f'{indir}/audio/' + filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
continue
|
continue
|
||||||
exists = False
|
exists = False
|
||||||
break
|
break
|
||||||
|
|
||||||
if not exists:
|
if not quantize_in_memory and not exists:
|
||||||
tmp = {}
|
tmp = {}
|
||||||
tmp[filename] = result
|
tmp[filename] = result
|
||||||
print(f"Audio not segmented, segmenting: {filename}")
|
print(f"Audio not segmented, segmenting: {filename}")
|
||||||
message = slice_dataset( voice, results=tmp )
|
message = slice_dataset( voice, results=tmp )
|
||||||
print(message)
|
print(message)
|
||||||
messages = messages + message.split("\n")
|
messages = messages + message.split("\n")
|
||||||
|
|
||||||
|
waveform = None
|
||||||
|
|
||||||
|
|
||||||
|
if quantize_in_memory:
|
||||||
|
path = f'{indir}/audio/{filename}'
|
||||||
|
if not os.path.exists(path):
|
||||||
|
path = f'./voices/{voice}/{filename}'
|
||||||
|
|
||||||
|
if not os.path.exists(path):
|
||||||
|
message = f"Audio not found: {path}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
#continue
|
||||||
|
else:
|
||||||
|
waveform = torchaudio.load(path)
|
||||||
|
waveform = resample(waveform[0], waveform[1], TARGET_SAMPLE_RATE)
|
||||||
|
|
||||||
if not use_segment:
|
if not use_segment:
|
||||||
segments[filename] = {
|
segments[filename] = {
|
||||||
|
@ -2533,13 +2589,18 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
'normalizer': normalizer,
|
'normalizer': normalizer,
|
||||||
'phonemes': result['phonemes'] if 'phonemes' in result else None
|
'phonemes': result['phonemes'] if 'phonemes' in result else None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if waveform:
|
||||||
|
segments[filename]['waveform'] = waveform
|
||||||
else:
|
else:
|
||||||
for segment in result['segments']:
|
for segment in result['segments']:
|
||||||
duration = segment['end'] - segment['start']
|
duration = segment['end'] - segment['start']
|
||||||
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
|
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
|
file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
|
||||||
|
|
||||||
|
segments[file] = {
|
||||||
'text': segment['text'],
|
'text': segment['text'],
|
||||||
'lang': lang,
|
'lang': lang,
|
||||||
'language': language,
|
'language': language,
|
||||||
|
@ -2547,22 +2608,27 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
'phonemes': segment['phonemes'] if 'phonemes' in segment else None
|
'phonemes': segment['phonemes'] if 'phonemes' in segment else None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if waveform:
|
||||||
|
sliced, error = slice_waveform( waveform[0], waveform[1], segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
|
||||||
|
if error:
|
||||||
|
message = f"{error}, skipping... {file}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
segments[file]['error'] = error
|
||||||
|
#continue
|
||||||
|
else:
|
||||||
|
segments[file]['waveform'] = (sliced, waveform[1])
|
||||||
|
|
||||||
jobs = {
|
jobs = {
|
||||||
'quantize': [[], []],
|
'quantize': [[], []],
|
||||||
'phonemize': [[], []],
|
'phonemize': [[], []],
|
||||||
}
|
}
|
||||||
|
|
||||||
for file in tqdm(segments, desc="Parsing segments"):
|
for file in tqdm(segments, desc="Parsing segments"):
|
||||||
|
extension = os.path.splitext(file)[-1][1:]
|
||||||
result = segments[file]
|
result = segments[file]
|
||||||
path = f'{indir}/audio/{file}'
|
path = f'{indir}/audio/{file}'
|
||||||
|
|
||||||
if not os.path.exists(path):
|
|
||||||
message = f"Missing segment, skipping... {file}"
|
|
||||||
print(message)
|
|
||||||
messages.append(message)
|
|
||||||
errored += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
text = result['text']
|
text = result['text']
|
||||||
lang = result['lang']
|
lang = result['lang']
|
||||||
language = result['language']
|
language = result['language']
|
||||||
|
@ -2573,28 +2639,20 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
|
|
||||||
normalized = normalizer(text) if normalize else text
|
normalized = normalizer(text) if normalize else text
|
||||||
|
|
||||||
if len(text) > 200:
|
if len(text) > MAX_TRAINING_CHAR_LENGTH:
|
||||||
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(text)}), skipping... {file}"
|
||||||
print(message)
|
print(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
errored += 1
|
errored += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
waveform, sample_rate = torchaudio.load(path)
|
# num_channels, num_frames = waveform.shape
|
||||||
num_channels, num_frames = waveform.shape
|
#duration = num_frames / sample_rate
|
||||||
duration = num_frames / sample_rate
|
|
||||||
|
|
||||||
error = validate_waveform( waveform, sample_rate )
|
|
||||||
if error:
|
|
||||||
message = f"{error}, skipping... {file}"
|
|
||||||
print(message)
|
|
||||||
messages.append(message)
|
|
||||||
errored += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
culled = len(text) < text_length
|
culled = len(text) < text_length
|
||||||
if not culled and audio_length > 0:
|
#if not culled and audio_length > 0:
|
||||||
culled = duration < audio_length
|
# culled = duration < audio_length
|
||||||
|
|
||||||
line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
|
line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
|
||||||
|
|
||||||
|
@ -2605,17 +2663,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
|
|
||||||
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
||||||
|
|
||||||
qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'
|
#phn_file = f'{indir}/valle/{file.replace(f".{extension}",".phn.txt")}'
|
||||||
if not os.path.exists(qnt_file):
|
phn_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".phn.txt")}'
|
||||||
jobs['quantize'][0].append(qnt_file)
|
|
||||||
jobs['quantize'][1].append((waveform, sample_rate))
|
|
||||||
"""
|
|
||||||
quantized = valle_quantize( waveform, sample_rate ).cpu()
|
|
||||||
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
|
||||||
print("Quantized:", file)
|
|
||||||
"""
|
|
||||||
|
|
||||||
phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}'
|
|
||||||
if not os.path.exists(phn_file):
|
if not os.path.exists(phn_file):
|
||||||
jobs['phonemize'][0].append(phn_file)
|
jobs['phonemize'][0].append(phn_file)
|
||||||
jobs['phonemize'][1].append(normalized)
|
jobs['phonemize'][1].append(normalized)
|
||||||
|
@ -2625,13 +2674,46 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
print("Phonemized:", file, normalized, text)
|
print("Phonemized:", file, normalized, text)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
#qnt_file = f'{indir}/valle/{file.replace(f".{extension}",".qnt.pt")}'
|
||||||
|
qnt_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".qnt.pt")}'
|
||||||
|
if 'error' not in result:
|
||||||
|
if not quantize_in_memory and not os.path.exists(path):
|
||||||
|
message = f"Missing segment, skipping... {file}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
errored += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not os.path.exists(qnt_file):
|
||||||
|
waveform = None
|
||||||
|
if 'waveform' in result:
|
||||||
|
waveform, sample_rate = result['waveform']
|
||||||
|
elif os.path.exists(path):
|
||||||
|
waveform, sample_rate = torchaudio.load(path)
|
||||||
|
error = validate_waveform( waveform, sample_rate )
|
||||||
|
if error:
|
||||||
|
message = f"{error}, skipping... {file}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
errored += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if waveform is not None:
|
||||||
|
jobs['quantize'][0].append(qnt_file)
|
||||||
|
jobs['quantize'][1].append((waveform, sample_rate))
|
||||||
|
"""
|
||||||
|
quantized = valle_quantize( waveform, sample_rate ).cpu()
|
||||||
|
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
||||||
|
print("Quantized:", file)
|
||||||
|
"""
|
||||||
|
|
||||||
for i in tqdm(range(len(jobs['quantize'][0])), desc="Quantizing"):
|
for i in tqdm(range(len(jobs['quantize'][0])), desc="Quantizing"):
|
||||||
qnt_file = jobs['quantize'][0][i]
|
qnt_file = jobs['quantize'][0][i]
|
||||||
waveform, sample_rate = jobs['quantize'][1][i]
|
waveform, sample_rate = jobs['quantize'][1][i]
|
||||||
|
|
||||||
quantized = valle_quantize( waveform, sample_rate ).cpu()
|
quantized = valle_quantize( waveform, sample_rate ).cpu()
|
||||||
torch.save(quantized, qnt_file)
|
torch.save(quantized, qnt_file)
|
||||||
print("Quantized:", qnt_file)
|
#print("Quantized:", qnt_file)
|
||||||
|
|
||||||
for i in tqdm(range(len(jobs['phonemize'][0])), desc="Phonemizing"):
|
for i in tqdm(range(len(jobs['phonemize'][0])), desc="Phonemizing"):
|
||||||
phn_file = jobs['phonemize'][0][i]
|
phn_file = jobs['phonemize'][0][i]
|
||||||
|
@ -2640,7 +2722,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
try:
|
try:
|
||||||
phonemized = valle_phonemize( normalized )
|
phonemized = valle_phonemize( normalized )
|
||||||
open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
|
open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
|
||||||
print("Phonemized:", phn_file)
|
#print("Phonemized:", phn_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
message = f"Failed to phonemize: {phn_file}: {normalized}"
|
message = f"Failed to phonemize: {phn_file}: {normalized}"
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
@ -2980,7 +3062,7 @@ def get_voice( name, dir=get_voice_dir(), load_latents=True ):
|
||||||
voice = voice + list(glob(f'{subj}/*.pth'))
|
voice = voice + list(glob(f'{subj}/*.pth'))
|
||||||
return sorted( voice )
|
return sorted( voice )
|
||||||
|
|
||||||
def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
def get_voice_list(dir=get_voice_dir(), append_defaults=False, extensions=["wav", "mp3", "flac", "pth"]):
|
||||||
defaults = [ "random", "microphone" ]
|
defaults = [ "random", "microphone" ]
|
||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
#res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
|
#res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
|
||||||
|
@ -2993,7 +3075,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
||||||
continue
|
continue
|
||||||
if len(os.listdir(os.path.join(dir, name))) == 0:
|
if len(os.listdir(os.path.join(dir, name))) == 0:
|
||||||
continue
|
continue
|
||||||
files = get_voice( name, dir=dir )
|
files = get_voice( name, dir=dir, extensions=extensions )
|
||||||
|
|
||||||
if len(files) > 0:
|
if len(files) > 0:
|
||||||
res.append(name)
|
res.append(name)
|
||||||
|
@ -3001,7 +3083,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
|
||||||
for subdir in os.listdir(f'{dir}/{name}'):
|
for subdir in os.listdir(f'{dir}/{name}'):
|
||||||
if not os.path.isdir(f'{dir}/{name}/{subdir}'):
|
if not os.path.isdir(f'{dir}/{name}/{subdir}'):
|
||||||
continue
|
continue
|
||||||
files = get_voice( f'{name}/{subdir}', dir=dir )
|
files = get_voice( f'{name}/{subdir}', dir=dir, extensions=extensions )
|
||||||
if len(files) == 0:
|
if len(files) == 0:
|
||||||
continue
|
continue
|
||||||
res.append(f'{name}/{subdir}')
|
res.append(f'{name}/{subdir}')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user