It slipped my mind that, technically, DAC can be used at any sample rate since it models waveforms; make this a config YAML option so that behavior can be enabled.

This commit is contained in:
mrq 2024-04-19 18:36:54 -05:00
parent 00804a47e9
commit a8ffa88844
3 changed files with 84 additions and 43 deletions

View File

@ -8,7 +8,7 @@ from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
input_audio = "voices_4" input_audio = "voices"
input_metadata = "metadata" input_metadata = "metadata"
output_dataset = "training" output_dataset = "training"
@ -34,7 +34,11 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
print("Does not exist:", metadata_path) print("Does not exist:", metadata_path)
continue continue
try:
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read()) metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
except Exception as e:
print("Failed to load metadata:", metadata_path, e)
continue
txts = [] txts = []
wavs = [] wavs = []
@ -51,6 +55,34 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
waveform, sample_rate = None, None waveform, sample_rate = None, None
language = metadata[filename]["language"] if "language" in metadata[filename] else "english" language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
if len(metadata[filename]["segments"]) == 0:
id = pad(0, 4)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
text = metadata[filename]["text"]
if len(text) == 0:
continue
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
continue
if not _replace_file_extension(outpath, ".json").exists():
txts.append((
outpath,
text,
language,
))
if not _replace_file_extension(outpath, ".dac").exists():
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
wavs.append((
outpath,
waveform,
sample_rate
))
else:
for segment in metadata[filename]["segments"]: for segment in metadata[filename]["segments"]:
id = pad(segment['id'], 4) id = pad(segment['id'], 4)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}') outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
@ -58,6 +90,14 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists(): if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
continue continue
if not _replace_file_extension(outpath, ".json").exists():
txts.append((
outpath,
segment["text"],
language,
))
if not _replace_file_extension(outpath, ".dac").exists():
if waveform is None: if waveform is None:
waveform, sample_rate = torchaudio.load(inpath) waveform, sample_rate = torchaudio.load(inpath)
@ -69,14 +109,9 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
if end >= waveform.shape[-1]: if end >= waveform.shape[-1]:
end = waveform.shape[-1] - 1 end = waveform.shape[-1] - 1
if not _replace_file_extension(outpath, ".json").exists(): if end - start < 0:
txts.append(( continue
outpath,
segment["text"],
language,
))
if not _replace_file_extension(outpath, ".dac").exists():
wavs.append(( wavs.append((
outpath, outpath,
waveform[:, start:end], waveform[:, start:end],
@ -86,7 +121,7 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
outpath, text, language = job outpath, text, language = job
phones = valle_phonemize(text) phones = valle_phonemize(text)
data = { data = {
"text": text, "text": text.strip(),
"phonemes": phones, "phonemes": phones,
"language": language, "language": language,
} }
@ -98,5 +133,5 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
qnt = valle_quantize(waveform, sr=sample_rate, device=device) qnt = valle_quantize(waveform, sr=sample_rate, device=device)
qnt.save(_replace_file_extension(outpath, ".dac")) qnt.save(_replace_file_extension(outpath, ".dac"))
except Exception as e: except Exception as e:
print(f"Failed to quantize: {speaker_id}") print(f"Failed to quantize: {outpath}:", e)
continue continue

View File

@ -542,9 +542,8 @@ class Config(_Config):
fp8: FP8 = field(default_factory=lambda: FP8) fp8: FP8 = field(default_factory=lambda: FP8)
@property sample_rate: int = 24_000
def sample_rate(self): variable_sample_rate: bool = False
return 24_000
@property @property
def distributed(self): def distributed(self):

View File

@ -7,7 +7,7 @@ import torchaudio
from functools import cache from functools import cache
from pathlib import Path from pathlib import Path
from typing import Union
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
@ -65,8 +65,7 @@ try:
recons = AudioSignal(recons, self.sample_rate) recons = AudioSignal(recons, self.sample_rate)
# to-do, original implementation # to-do, original implementation
""" if not hasattr(obj, "dummy") or not obj.dummy:
"""
resample_fn = recons.resample resample_fn = recons.resample
loudness_fn = recons.loudness loudness_fn = recons.loudness
@ -89,7 +88,7 @@ try:
except Exception as e: except Exception as e:
cfg.inference.use_dac = False cfg.inference.use_dac = False
print(str(e))
@cache @cache
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels): def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
assert cfg.sample_rate == 24_000 assert cfg.sample_rate == 24_000
@ -164,6 +163,10 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
model = model.eval() model = model.eval()
# extra metadata # extra metadata
# since DAC moreso models against waveforms, we can actually use a smaller sample rate
# updating it here will affect the sample rate the waveform is resampled to on encoding
if cfg.variable_sample_rate:
model.sample_rate = cfg.sample_rate model.sample_rate = cfg.sample_rate
model.backend = "dac" model.backend = "dac"
@ -205,9 +208,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
# DAC uses a different pathway # DAC uses a different pathway
if model.backend == "dac": if model.backend == "dac":
dummy = False
if metadata is None: if metadata is None:
metadata = dict( metadata = dict(
chunk_length=416, chunk_length=120,
original_length=0, original_length=0,
input_db=-12, input_db=-12,
channels=1, channels=1,
@ -215,6 +219,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
padding=False, padding=False,
dac_version='1.0.0', dac_version='1.0.0',
) )
dummy = True
# generate object with copied metadata # generate object with copied metadata
artifact = DACFile( artifact = DACFile(
codes = codes, codes = codes,
@ -227,8 +232,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding, padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version, dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
) )
artifact.dummy = dummy
return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate # to-do: inject the sample rate encoded at, because we can actually decouple
return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
kwargs = {} kwargs = {}