From a8ffa88844152336c52a3507f3498169849b6def Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 19 Apr 2024 18:36:54 -0500
Subject: [PATCH] it slipped my mind that technically DAC can be used at any
 sample rate, since it models waveforms; make it a config YAML option to allow
 this behavior

---
 scripts/process_old_dataaset.py | 73 ++++++++++++++++++++++++---------
 vall_e/config.py                |  5 +--
 vall_e/emb/qnt.py               | 49 ++++++++++++----------
 3 files changed, 84 insertions(+), 43 deletions(-)

diff --git a/scripts/process_old_dataaset.py b/scripts/process_old_dataaset.py
index 47d0941..928c043 100644
--- a/scripts/process_old_dataaset.py
+++ b/scripts/process_old_dataaset.py
@@ -8,7 +8,7 @@ from pathlib import Path
 from vall_e.emb.g2p import encode as valle_phonemize
 from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
 
-input_audio = "voices_4"
+input_audio = "voices"
 input_metadata = "metadata"
 output_dataset = "training"
 
@@ -34,7 +34,11 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 			print("Does not exist:", metadata_path)
 			continue
 
-		metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
+		try:
+			metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
+		except Exception as e:
+			print("Failed to load metadata:", metadata_path, e)
+			continue
 
 		txts = []
 		wavs = []
@@ -51,42 +55,73 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 			waveform, sample_rate = None, None
 			language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
 
-			for segment in metadata[filename]["segments"]:
-				id = pad(segment['id'], 4)
+			if len(metadata[filename]["segments"]) == 0:
+				id = pad(0, 4)
 				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
+				text = metadata[filename]["text"]
+
+				if len(text) == 0:
+					continue
 
 				if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
 					continue
 
-				if waveform is None:
-					waveform, sample_rate = torchaudio.load(inpath)
-
-				start = int(segment['start'] * sample_rate)
-				end = int(segment['end'] * sample_rate)
-
-				if start < 0:
-					start = 0
-				if end >= waveform.shape[-1]:
-					end = waveform.shape[-1] - 1
-
 				if not _replace_file_extension(outpath, ".json").exists():
 					txts.append((
 						outpath,
-						segment["text"],
+						text,
 						language,
 					))
 
 				if not _replace_file_extension(outpath, ".dac").exists():
+					if waveform is None:
+						waveform, sample_rate = torchaudio.load(inpath)
+
 					wavs.append((
 						outpath,
-						waveform[:, start:end],
+						waveform,
 						sample_rate
 					))
+			else:
+				for segment in metadata[filename]["segments"]:
+					id = pad(segment['id'], 4)
+					outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
+
+					if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
+						continue
+
+					if not _replace_file_extension(outpath, ".json").exists():
+						txts.append((
+							outpath,
+							segment["text"],
+							language,
+						))
+
+					if not _replace_file_extension(outpath, ".dac").exists():
+						if waveform is None:
+							waveform, sample_rate = torchaudio.load(inpath)
+
+						start = int(segment['start'] * sample_rate)
+						end = int(segment['end'] * sample_rate)
+
+						if start < 0:
+							start = 0
+						if end >= waveform.shape[-1]:
+							end = waveform.shape[-1] - 1
+
+						if end - start < 0:
+							continue
+
+						wavs.append((
+							outpath,
+							waveform[:, start:end],
+							sample_rate
+						))
 
 		for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"):
 			outpath, text, language = job
 			phones = valle_phonemize(text)
 			data = {
-				"text": text,
+				"text": text.strip(),
 				"phonemes": phones,
 				"language": language,
 			}
@@ -98,5 +133,5 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 				qnt = valle_quantize(waveform, sr=sample_rate, device=device)
 				qnt.save(_replace_file_extension(outpath, ".dac"))
 			except Exception as e:
-				print(f"Failed to quantize: {speaker_id}")
+				print(f"Failed to quantize: {outpath}:", e)
 				continue
diff --git a/vall_e/config.py b/vall_e/config.py
index 4f911ef..1081e77 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -542,9 +542,8 @@ class Config(_Config):
 	fp8: FP8 = field(default_factory=lambda: FP8)
 
-	@property
-	def sample_rate(self):
-		return 24_000
+	sample_rate: int = 24_000
+	variable_sample_rate: bool = False
 
 	@property
 	def distributed(self):
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index c5af8ac..30b8bdd 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -7,7 +7,7 @@ import torchaudio
 
 from functools import cache
 from pathlib import Path
-
+from typing import Union
 from einops import rearrange
 
 from torch import Tensor
@@ -65,23 +65,22 @@ try:
 			recons = AudioSignal(recons, self.sample_rate)
 
 			# to-do, original implementation
-			"""
-			"""
-			resample_fn = recons.resample
-			loudness_fn = recons.loudness
-
-			# If audio is > 10 minutes long, use the ffmpeg versions
-			if recons.signal_duration >= 10 * 60 * 60:
-				resample_fn = recons.ffmpeg_resample
-				loudness_fn = recons.ffmpeg_loudness
+			if not hasattr(obj, "dummy") or not obj.dummy:
+				resample_fn = recons.resample
+				loudness_fn = recons.loudness
+
+				# If audio is > 10 minutes long, use the ffmpeg versions
+				if recons.signal_duration >= 10 * 60 * 60:
+					resample_fn = recons.ffmpeg_resample
+					loudness_fn = recons.ffmpeg_loudness
 
-			recons.normalize(obj.input_db)
-			resample_fn(obj.sample_rate)
-			recons = recons[..., : obj.original_length]
-			loudness_fn()
-			recons.audio_data = recons.audio_data.reshape(
-				-1, obj.channels, obj.original_length
-			)
+				recons.normalize(obj.input_db)
+				resample_fn(obj.sample_rate)
+				recons = recons[..., : obj.original_length]
+				loudness_fn()
+				recons.audio_data = recons.audio_data.reshape(
+					-1, obj.channels, obj.original_length
+				)
 
 		self.padding = original_padding
 		return recons
@@ -89,7 +88,7 @@ try:
 except Exception as e:
 	cfg.inference.use_dac = False
-
+	print(str(e))
 
 @cache
 def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
 	assert cfg.sample_rate == 24_000
@@ -164,7 +163,11 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
 	model = model.eval()
 
 	# extra metadata
-	model.sample_rate = cfg.sample_rate
+
+	# since DAC moreso models against waveforms, we can actually use a smaller sample rate
+	# updating it here will affect the sample rate the waveform is resampled to on encoding
+	if cfg.variable_sample_rate:
+		model.sample_rate = cfg.sample_rate
 	model.backend = "dac"
 
 	return model
@@ -205,9 +208,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 
 	# DAC uses a different pathway
 	if model.backend == "dac":
+		dummy = False
 		if metadata is None:
 			metadata = dict(
-				chunk_length=416,
+				chunk_length=120,
 				original_length=0,
 				input_db=-12,
 				channels=1,
@@ -215,6 +219,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 				padding=False,
 				dac_version='1.0.0',
 			)
+			dummy = True
 		# generate object with copied metadata
 		artifact = DACFile(
 			codes = codes,
@@ -227,8 +232,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 			padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
 			dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version, ) + artifact.dummy = dummy - return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate + # to-do: inject the sample rate encoded at, because we can actually decouple + return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate kwargs = {}