Made initialization faster when there are a lot of voice files (because glob is painfully slow on large directories); committing changes that were buried on my training rig

This commit is contained in:
mrq 2023-08-21 03:31:49 +00:00
parent 91a0c495ff
commit 72a38ff2fc
3 changed files with 201 additions and 118 deletions

@ -1 +1 @@
Subproject commit 5ff00bf3bfa97e2c8e9f166b920273f83ac9d8f0
Subproject commit cbd3c95c42ac1da9772f61b9895954ee693075c9

View File

@ -8,3 +8,4 @@ voicefixer
psutil
phonemizer
pydantic==1.10.11
websockets

View File

@ -45,7 +45,7 @@ from tortoise.utils.device import get_device_name, set_device_name, get_device_c
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp", "m-bain/whisperx"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
@ -61,12 +61,15 @@ RESAMPLERS = {}
MIN_TRAINING_DURATION = 0.6
MAX_TRAINING_DURATION = 11.6097505669
MAX_TRAINING_CHAR_LENGTH = 200
VALLE_ENABLED = False
BARK_ENABLED = False
VERBOSE_DEBUG = True
import traceback
try:
from whisper.normalizers.english import EnglishTextNormalizer
from whisper.normalizers.basic import BasicTextNormalizer
@ -75,7 +78,7 @@ try:
print("Whisper detected")
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
print(traceback.format_exc())
pass
try:
@ -90,12 +93,14 @@ try:
VALLE_ENABLED = True
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
print(traceback.format_exc())
pass
if VALLE_ENABLED:
TTSES.append('vall-e')
# torchaudio.set_audio_backend('soundfile')
try:
import bark
from bark import text_to_semantic
@ -109,35 +114,10 @@ try:
BARK_ENABLED = True
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
print(traceback.format_exc())
pass
if BARK_ENABLED:
try:
from vocos import Vocos
VOCOS_ENABLED = True
print("Vocos detected")
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
VOCOS_ENABLED = False
try:
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
HUBERT_ENABLED = True
print("HuBERT detected")
except Exception as e:
if VERBOSE_DEBUG:
print("Error:", e)
HUBERT_ENABLED = False
TTSES.append('bark')
def semantic_to_audio_tokens(
@ -181,7 +161,32 @@ if BARK_ENABLED:
self.device = get_device_name()
if VOCOS_ENABLED:
try:
from vocos import Vocos
self.vocos_enabled = True
print("Vocos detected")
except Exception as e:
if VERBOSE_DEBUG:
print(traceback.format_exc())
self.vocos_enabled = False
try:
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer
hubert_manager = HuBERTManager()
hubert_manager.make_sure_hubert_installed()
hubert_manager.make_sure_tokenizer_installed()
self.hubert_enabled = True
print("HuBERT detected")
except Exception as e:
if VERBOSE_DEBUG:
print(traceback.format_exc())
self.hubert_enabled = False
if self.vocos_enabled:
self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device)
def create_voice( self, voice ):
@ -238,7 +243,7 @@ if BARK_ENABLED:
# generate semantic tokens
if HUBERT_ENABLED:
if self.hubert_enabled:
wav = wav.to(self.device)
# Extract discrete codes from EnCodec
@ -426,7 +431,7 @@ def generate_bark(**kwargs):
idx_cache = {}
for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file)
extension = os.path.splitext(filename)[1]
extension = os.path.splitext(filename)[-1][1:]
if extension != ".json" and extension != ".wav":
continue
match = re.findall(rf"^{cleanup_voice_name(voice)}_(\d+)(?:.+?)?{extension}$", filename)
@ -672,18 +677,23 @@ def generate_valle(**kwargs):
voice_cache = {}
def fetch_voice( voice ):
if voice in voice_cache:
return voice_cache[voice]
voice_dir = f'./training/{voice}/audio/'
if not os.path.isdir(voice_dir):
if not os.path.isdir(voice_dir) or len(os.listdir(voice_dir)) == 0:
voice_dir = f'./voices/{voice}/'
files = [ f'{voice_dir}/{d}' for d in os.listdir(voice_dir) if d[-4:] == ".wav" ]
# return files
return random.choice(files)
voice_cache[voice] = random.choice(files)
return voice_cache[voice]
def get_settings( override=None ):
settings = {
'ar_temp': float(parameters['temperature']),
'nar_temp': float(parameters['temperature']),
'max_ar_samples': parameters['num_autoregressive_samples'],
'max_ar_steps': parameters['num_autoregressive_samples'],
}
# could be better to just do a ternary on everything above, but i am not a professional
@ -697,7 +707,7 @@ def generate_valle(**kwargs):
continue
settings[k] = override[k]
settings['reference'] = fetch_voice(voice=selected_voice)
settings['references'] = [ fetch_voice(voice=selected_voice) for _ in range(3) ]
return settings
if not parameters['delimiter']:
@ -723,7 +733,7 @@ def generate_valle(**kwargs):
idx_cache = {}
for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file)
extension = os.path.splitext(filename)[1]
extension = os.path.splitext(filename)[-1][1:]
if extension != ".json" and extension != ".wav":
continue
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
@ -783,11 +793,14 @@ def generate_valle(**kwargs):
except Exception as e:
raise Exception("Prompt settings editing requested, but received invalid JSON")
settings = get_settings( override=override )
reference = settings['reference']
settings.pop("reference")
name = get_name(line=line, candidate=0)
gen = tts.inference(cut_text, reference, **settings )
settings = get_settings( override=override )
references = settings['references']
settings.pop("references")
settings['out_path'] = f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav'
gen = tts.inference(cut_text, references, **settings )
run_time = time.time()-start_time
print(f"Generating line took {run_time} seconds")
@ -805,7 +818,7 @@ def generate_valle(**kwargs):
# save here in case some error happens mid-batch
#torchaudio.save(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu(), sr)
soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
#soundfile.write(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav', wav.cpu()[0,0], sr)
wav, sr = torchaudio.load(f'{outdir}/{cleanup_voice_name(voice)}_{name}.wav')
audio_cache[name] = {
@ -1085,7 +1098,7 @@ def generate_tortoise(**kwargs):
idx_cache = {}
for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file)
extension = os.path.splitext(filename)[1]
extension = os.path.splitext(filename)[-1][1:]
if extension != ".json" and extension != ".wav":
continue
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
@ -1605,30 +1618,18 @@ class TrainingState():
if args.tts_backend == "vall-e":
keys['lrs'] = [
'ar.lr', 'nar.lr',
'ar-half.lr', 'nar-half.lr',
'ar-quarter.lr', 'nar-quarter.lr',
]
keys['losses'] = [
'ar.loss', 'nar.loss', 'ar+nar.loss',
'ar-half.loss', 'nar-half.loss', 'ar-half+nar-half.loss',
'ar-quarter.loss', 'nar-quarter.loss', 'ar-quarter+nar-quarter.loss',
# 'ar.loss', 'nar.loss', 'ar+nar.loss',
'ar.loss.nll', 'nar.loss.nll',
'ar-half.loss.nll', 'nar-half.loss.nll',
'ar-quarter.loss.nll', 'nar-quarter.loss.nll',
]
keys['accuracies'] = [
'ar.loss.acc', 'nar.loss.acc',
'ar-half.loss.acc', 'nar-half.loss.acc',
'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
'ar.stats.acc', 'nar.loss.acc',
]
keys['precisions'] = [
'ar.loss.precision', 'nar.loss.precision',
'ar-half.loss.precision', 'nar-half.loss.precision',
'ar-quarter.loss.precision', 'nar-quarter.loss.precision',
]
keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']
keys['precisions'] = [ 'ar.loss.precision', 'nar.loss.precision', ]
keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm']
for k in keys['lrs']:
if k not in self.info:
@ -1746,7 +1747,8 @@ class TrainingState():
if args.tts_backend == "tortoise":
logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
else:
logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ])
log_dir = "logs"
logs = sorted([f'{self.training_dir}/{log_dir}/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/{log_dir}/') ])
if update:
logs = [logs[-1]]
@ -2220,6 +2222,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json'
quantize_in_memory = args.tts_backend == "vall-e"
os.makedirs(f'{indir}/audio/', exist_ok=True)
TARGET_SAMPLE_RATE = 22050
@ -2245,13 +2249,24 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
continue
results[basename] = result
if not quantize_in_memory:
waveform, sample_rate = torchaudio.load(file)
# resample to the input rate, since it'll get resampled for training anyways
# this should also "help" increase throughput a bit when filling the dataloaders
waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
if waveform.shape[0] == 2:
waveform = waveform[:1]
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
try:
kwargs = {}
if basename[-4:] == ".wav":
kwargs['encoding'] = "PCM_S"
kwargs['bits_per_sample'] = 16
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, **kwargs)
except Exception as e:
print(e)
with open(infile, 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t'))
@ -2317,6 +2332,9 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
segments = 0
for filename in results:
path = f'./voices/{voice}/{filename}'
extension = os.path.splitext(filename)[-1][1:]
out_extension = extension # "wav"
if not os.path.exists(path):
path = f'./training/{voice}/{filename}'
@ -2333,7 +2351,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
duration = num_frames / sample_rate
for segment in result['segments']:
file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
if error:
@ -2341,12 +2359,17 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
print(message)
messages.append(message)
continue
sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
# sliced, _ = resample( sliced, sample_rate, TARGET_SAMPLE_RATE )
if waveform.shape[0] == 2:
waveform = waveform[:1]
torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, encoding="PCM_S", bits_per_sample=16)
kwargs = {}
if file[-4:] == ".wav":
kwargs['encoding'] = "PCM_S"
kwargs['bits_per_sample'] = 16
torchaudio.save(f"{indir}/audio/{file}", sliced, TARGET_SAMPLE_RATE, **kwargs)
segments +=1
@ -2462,18 +2485,32 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
errored = 0
messages = []
normalize = True
normalize = False # True
phonemize = should_phonemize()
lines = { 'training': [], 'validation': [] }
segments = {}
quantize_in_memory = args.tts_backend == "vall-e"
if args.tts_backend != "tortoise":
text_length = 0
audio_length = 0
start_offset = -0.1
end_offset = 0.1
trim_silence = False
TARGET_SAMPLE_RATE = 22050
if args.tts_backend != "tortoise":
TARGET_SAMPLE_RATE = 24000
if tts:
TARGET_SAMPLE_RATE = tts.input_sample_rate
for filename in tqdm(results, desc="Parsing results"):
use_segment = use_segments
extension = os.path.splitext(filename)[-1][1:]
out_extension = extension # "wav"
result = results[filename]
lang = result['language']
language = LANGUAGES[lang] if lang in LANGUAGES else lang
@ -2481,8 +2518,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
# check if unsegmented text exceeds 200 characters
if not use_segment:
if len(result['text']) > 200:
message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}"
if len(result['text']) > MAX_TRAINING_CHAR_LENGTH:
message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(result['text'])}), using segments: {filename}"
print(message)
messages.append(message)
use_segment = True
@ -2490,13 +2527,15 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
# check if unsegmented audio exceeds 11.6s
if not use_segment:
path = f'{indir}/audio/{filename}'
if not os.path.exists(path):
if not quantize_in_memory and not os.path.exists(path):
messages.append(f"Missing source audio: {filename}")
errored += 1
continue
metadata = torchaudio.info(path)
duration = metadata.num_frames / metadata.sample_rate
duration = 0
for segment in result['segments']:
duration = max(duration, result['segments'][segment]['end'])
if duration >= MAX_TRAINING_DURATION:
message = f"Audio too large, using segments: {filename}"
print(message)
@ -2511,13 +2550,13 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
continue
path = f'{indir}/audio/' + filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
path = f'{indir}/audio/' + filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
if os.path.exists(path):
continue
exists = False
break
if not exists:
if not quantize_in_memory and not exists:
tmp = {}
tmp[filename] = result
print(f"Audio not segmented, segmenting: {filename}")
@ -2525,6 +2564,23 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
print(message)
messages = messages + message.split("\n")
waveform = None
if quantize_in_memory:
path = f'{indir}/audio/{filename}'
if not os.path.exists(path):
path = f'./voices/{voice}/{filename}'
if not os.path.exists(path):
message = f"Audio not found: {path}"
print(message)
messages.append(message)
#continue
else:
waveform = torchaudio.load(path)
waveform = resample(waveform[0], waveform[1], TARGET_SAMPLE_RATE)
if not use_segment:
segments[filename] = {
'text': result['text'],
@ -2533,13 +2589,18 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
'normalizer': normalizer,
'phonemes': result['phonemes'] if 'phonemes' in result else None
}
if waveform:
segments[filename]['waveform'] = waveform
else:
for segment in result['segments']:
duration = segment['end'] - segment['start']
if duration <= MIN_TRAINING_DURATION or MAX_TRAINING_DURATION <= duration:
continue
segments[filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")] = {
file = filename.replace(f".{extension}", f"_{pad(segment['id'], 4)}.{out_extension}")
segments[file] = {
'text': segment['text'],
'lang': lang,
'language': language,
@ -2547,22 +2608,27 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
'phonemes': segment['phonemes'] if 'phonemes' in segment else None
}
if waveform:
sliced, error = slice_waveform( waveform[0], waveform[1], segment['start'] + start_offset, segment['end'] + end_offset, trim_silence )
if error:
message = f"{error}, skipping... {file}"
print(message)
messages.append(message)
segments[file]['error'] = error
#continue
else:
segments[file]['waveform'] = (sliced, waveform[1])
jobs = {
'quantize': [[], []],
'phonemize': [[], []],
}
for file in tqdm(segments, desc="Parsing segments"):
extension = os.path.splitext(file)[-1][1:]
result = segments[file]
path = f'{indir}/audio/{file}'
if not os.path.exists(path):
message = f"Missing segment, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
text = result['text']
lang = result['lang']
language = result['language']
@ -2573,28 +2639,20 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
normalized = normalizer(text) if normalize else text
if len(text) > 200:
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
if len(text) > MAX_TRAINING_CHAR_LENGTH:
message = f"Text length too long ({MAX_TRAINING_CHAR_LENGTH} < {len(text)}), skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
waveform, sample_rate = torchaudio.load(path)
num_channels, num_frames = waveform.shape
duration = num_frames / sample_rate
# num_channels, num_frames = waveform.shape
#duration = num_frames / sample_rate
error = validate_waveform( waveform, sample_rate )
if error:
message = f"{error}, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
culled = len(text) < text_length
if not culled and audio_length > 0:
culled = duration < audio_length
#if not culled and audio_length > 0:
# culled = duration < audio_length
line = f'audio/{file}|{phonemes if phonemize and phonemes else text}'
@ -2605,17 +2663,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
os.makedirs(f'{indir}/valle/', exist_ok=True)
qnt_file = f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'
if not os.path.exists(qnt_file):
jobs['quantize'][0].append(qnt_file)
jobs['quantize'][1].append((waveform, sample_rate))
"""
quantized = valle_quantize( waveform, sample_rate ).cpu()
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
print("Quantized:", file)
"""
phn_file = f'{indir}/valle/{file.replace(".wav",".phn.txt")}'
#phn_file = f'{indir}/valle/{file.replace(f".{extension}",".phn.txt")}'
phn_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".phn.txt")}'
if not os.path.exists(phn_file):
jobs['phonemize'][0].append(phn_file)
jobs['phonemize'][1].append(normalized)
@ -2625,13 +2674,46 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
print("Phonemized:", file, normalized, text)
"""
#qnt_file = f'{indir}/valle/{file.replace(f".{extension}",".qnt.pt")}'
qnt_file = f'./training/valle/data/{voice}/{file.replace(f".{extension}",".qnt.pt")}'
if 'error' not in result:
if not quantize_in_memory and not os.path.exists(path):
message = f"Missing segment, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
if not os.path.exists(qnt_file):
waveform = None
if 'waveform' in result:
waveform, sample_rate = result['waveform']
elif os.path.exists(path):
waveform, sample_rate = torchaudio.load(path)
error = validate_waveform( waveform, sample_rate )
if error:
message = f"{error}, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
if waveform is not None:
jobs['quantize'][0].append(qnt_file)
jobs['quantize'][1].append((waveform, sample_rate))
"""
quantized = valle_quantize( waveform, sample_rate ).cpu()
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
print("Quantized:", file)
"""
for i in tqdm(range(len(jobs['quantize'][0])), desc="Quantizing"):
qnt_file = jobs['quantize'][0][i]
waveform, sample_rate = jobs['quantize'][1][i]
quantized = valle_quantize( waveform, sample_rate ).cpu()
torch.save(quantized, qnt_file)
print("Quantized:", qnt_file)
#print("Quantized:", qnt_file)
for i in tqdm(range(len(jobs['phonemize'][0])), desc="Phonemizing"):
phn_file = jobs['phonemize'][0][i]
@ -2640,7 +2722,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
try:
phonemized = valle_phonemize( normalized )
open(phn_file, 'w', encoding='utf-8').write(" ".join(phonemized))
print("Phonemized:", phn_file)
#print("Phonemized:", phn_file)
except Exception as e:
message = f"Failed to phonemize: {phn_file}: {normalized}"
messages.append(message)
@ -2980,7 +3062,7 @@ def get_voice( name, dir=get_voice_dir(), load_latents=True ):
voice = voice + list(glob(f'{subj}/*.pth'))
return sorted( voice )
def get_voice_list(dir=get_voice_dir(), append_defaults=False):
def get_voice_list(dir=get_voice_dir(), append_defaults=False, extensions=["wav", "mp3", "flac", "pth"]):
defaults = [ "random", "microphone" ]
os.makedirs(dir, exist_ok=True)
#res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
@ -2993,7 +3075,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
continue
if len(os.listdir(os.path.join(dir, name))) == 0:
continue
files = get_voice( name, dir=dir )
files = get_voice( name, dir=dir, extensions=extensions )
if len(files) > 0:
res.append(name)
@ -3001,7 +3083,7 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
for subdir in os.listdir(f'{dir}/{name}'):
if not os.path.isdir(f'{dir}/{name}/{subdir}'):
continue
files = get_voice( f'{name}/{subdir}', dir=dir )
files = get_voice( f'{name}/{subdir}', dir=dir, extensions=extensions )
if len(files) == 0:
continue
res.append(f'{name}/{subdir}')