It slipped my mind that DAC can technically be used at any sample rate, since it models waveforms; added a config YAML option to allow this behavior

This commit is contained in:
mrq 2024-04-19 18:36:54 -05:00
parent 00804a47e9
commit a8ffa88844
3 changed files with 84 additions and 43 deletions

View File

@ -8,7 +8,7 @@ from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
input_audio = "voices_4"
input_audio = "voices"
input_metadata = "metadata"
output_dataset = "training"
@ -34,7 +34,11 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
print("Does not exist:", metadata_path)
continue
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
try:
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
except Exception as e:
print("Failed to load metadata:", metadata_path, e)
continue
txts = []
wavs = []
@ -51,42 +55,73 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
waveform, sample_rate = None, None
language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
for segment in metadata[filename]["segments"]:
id = pad(segment['id'], 4)
if len(metadata[filename]["segments"]) == 0:
id = pad(0, 4)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
text = metadata[filename]["text"]
if len(text) == 0:
continue
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
continue
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
start = int(segment['start'] * sample_rate)
end = int(segment['end'] * sample_rate)
if start < 0:
start = 0
if end >= waveform.shape[-1]:
end = waveform.shape[-1] - 1
if not _replace_file_extension(outpath, ".json").exists():
txts.append((
outpath,
segment["text"],
text,
language,
))
if not _replace_file_extension(outpath, ".dac").exists():
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
wavs.append((
outpath,
waveform[:, start:end],
waveform,
sample_rate
))
else:
for segment in metadata[filename]["segments"]:
id = pad(segment['id'], 4)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
continue
if not _replace_file_extension(outpath, ".json").exists():
txts.append((
outpath,
segment["text"],
language,
))
if not _replace_file_extension(outpath, ".dac").exists():
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
start = int(segment['start'] * sample_rate)
end = int(segment['end'] * sample_rate)
if start < 0:
start = 0
if end >= waveform.shape[-1]:
end = waveform.shape[-1] - 1
if end - start < 0:
continue
wavs.append((
outpath,
waveform[:, start:end],
sample_rate
))
for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"):
outpath, text, language = job
phones = valle_phonemize(text)
data = {
"text": text,
"text": text.strip(),
"phonemes": phones,
"language": language,
}
@ -98,5 +133,5 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
qnt.save(_replace_file_extension(outpath, ".dac"))
except Exception as e:
print(f"Failed to quantize: {speaker_id}")
print(f"Failed to quantize: {outpath}:", e)
continue

View File

@ -542,9 +542,8 @@ class Config(_Config):
fp8: FP8 = field(default_factory=lambda: FP8)
@property
def sample_rate(self):
return 24_000
sample_rate: int = 24_000
variable_sample_rate: bool = False
@property
def distributed(self):

View File

@ -7,7 +7,7 @@ import torchaudio
from functools import cache
from pathlib import Path
from typing import Union
from einops import rearrange
from torch import Tensor
@ -65,23 +65,22 @@ try:
recons = AudioSignal(recons, self.sample_rate)
# to-do, original implementation
"""
"""
resample_fn = recons.resample
loudness_fn = recons.loudness
if not hasattr(obj, "dummy") or not obj.dummy:
resample_fn = recons.resample
loudness_fn = recons.loudness
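# If audio is very long, use the ffmpeg versions (NOTE: the threshold below is 10 * 60 * 60, i.e. 10 hours rather than 10 minutes if signal_duration is in seconds)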
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
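# If audio is very long, use the ffmpeg versions (NOTE: the threshold below is 10 * 60 * 60, i.e. 10 hours rather than 10 minutes if signal_duration is in seconds)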
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length
)
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length
)
self.padding = original_padding
return recons
@ -89,7 +88,7 @@ try:
except Exception as e:
cfg.inference.use_dac = False
print(str(e))
@cache
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
assert cfg.sample_rate == 24_000
@ -164,7 +163,11 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
model = model.eval()
# extra metadata
model.sample_rate = cfg.sample_rate
# since DAC more so models against waveforms, we can actually use a smaller sample rate;
# updating it here will affect the sample rate the waveform is resampled to on encoding
if cfg.variable_sample_rate:
model.sample_rate = cfg.sample_rate
model.backend = "dac"
return model
@ -205,9 +208,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
# DAC uses a different pathway
if model.backend == "dac":
dummy = False
if metadata is None:
metadata = dict(
chunk_length=416,
chunk_length=120,
original_length=0,
input_db=-12,
channels=1,
@ -215,6 +219,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
padding=False,
dac_version='1.0.0',
)
dummy = True
# generate object with copied metadata
artifact = DACFile(
codes = codes,
@ -227,8 +232,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
)
artifact.dummy = dummy
return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate
# to-do: inject the sample rate the audio was encoded at, because we can actually decouple it (presumably from the model's own sample rate)
return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
kwargs = {}