it slipped my mind that DAC can technically be used at any sample rate, since it models waveforms; make this behavior configurable via a YAML config option

This commit is contained in:
mrq 2024-04-19 18:36:54 -05:00
parent 00804a47e9
commit a8ffa88844
3 changed files with 84 additions and 43 deletions

View File

@ -8,7 +8,7 @@ from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
input_audio = "voices_4" input_audio = "voices"
input_metadata = "metadata" input_metadata = "metadata"
output_dataset = "training" output_dataset = "training"
@ -34,7 +34,11 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
print("Does not exist:", metadata_path) print("Does not exist:", metadata_path)
continue continue
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read()) try:
metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
except Exception as e:
print("Failed to load metadata:", metadata_path, e)
continue
txts = [] txts = []
wavs = [] wavs = []
@ -51,42 +55,73 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
waveform, sample_rate = None, None waveform, sample_rate = None, None
language = metadata[filename]["language"] if "language" in metadata[filename] else "english" language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
for segment in metadata[filename]["segments"]: if len(metadata[filename]["segments"]) == 0:
id = pad(segment['id'], 4) id = pad(0, 4)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}') outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
text = metadata[filename]["text"]
if len(text) == 0:
continue
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists(): if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
continue continue
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
start = int(segment['start'] * sample_rate)
end = int(segment['end'] * sample_rate)
if start < 0:
start = 0
if end >= waveform.shape[-1]:
end = waveform.shape[-1] - 1
if not _replace_file_extension(outpath, ".json").exists(): if not _replace_file_extension(outpath, ".json").exists():
txts.append(( txts.append((
outpath, outpath,
segment["text"], text,
language, language,
)) ))
if not _replace_file_extension(outpath, ".dac").exists(): if not _replace_file_extension(outpath, ".dac").exists():
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
wavs.append(( wavs.append((
outpath, outpath,
waveform[:, start:end], waveform,
sample_rate sample_rate
)) ))
else:
for segment in metadata[filename]["segments"]:
id = pad(segment['id'], 4)
outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
continue
if not _replace_file_extension(outpath, ".json").exists():
txts.append((
outpath,
segment["text"],
language,
))
if not _replace_file_extension(outpath, ".dac").exists():
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
start = int(segment['start'] * sample_rate)
end = int(segment['end'] * sample_rate)
if start < 0:
start = 0
if end >= waveform.shape[-1]:
end = waveform.shape[-1] - 1
if end - start < 0:
continue
wavs.append((
outpath,
waveform[:, start:end],
sample_rate
))
for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"): for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"):
outpath, text, language = job outpath, text, language = job
phones = valle_phonemize(text) phones = valle_phonemize(text)
data = { data = {
"text": text, "text": text.strip(),
"phonemes": phones, "phonemes": phones,
"language": language, "language": language,
} }
@ -98,5 +133,5 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
qnt = valle_quantize(waveform, sr=sample_rate, device=device) qnt = valle_quantize(waveform, sr=sample_rate, device=device)
qnt.save(_replace_file_extension(outpath, ".dac")) qnt.save(_replace_file_extension(outpath, ".dac"))
except Exception as e: except Exception as e:
print(f"Failed to quantize: {speaker_id}") print(f"Failed to quantize: {outpath}:", e)
continue continue

View File

@ -542,9 +542,8 @@ class Config(_Config):
fp8: FP8 = field(default_factory=lambda: FP8) fp8: FP8 = field(default_factory=lambda: FP8)
@property sample_rate: int = 24_000
def sample_rate(self): variable_sample_rate: bool = False
return 24_000
@property @property
def distributed(self): def distributed(self):

View File

@ -7,7 +7,7 @@ import torchaudio
from functools import cache from functools import cache
from pathlib import Path from pathlib import Path
from typing import Union
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
@ -65,23 +65,22 @@ try:
recons = AudioSignal(recons, self.sample_rate) recons = AudioSignal(recons, self.sample_rate)
# to-do, original implementation # to-do, original implementation
""" if not hasattr(obj, "dummy") or not obj.dummy:
""" resample_fn = recons.resample
resample_fn = recons.resample loudness_fn = recons.loudness
loudness_fn = recons.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
# If audio is > 10 minutes long, use the ffmpeg versions if recons.signal_duration >= 10 * 60 * 60:
if recons.signal_duration >= 10 * 60 * 60: resample_fn = recons.ffmpeg_resample
resample_fn = recons.ffmpeg_resample loudness_fn = recons.ffmpeg_loudness
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db) recons.normalize(obj.input_db)
resample_fn(obj.sample_rate) resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length] recons = recons[..., : obj.original_length]
loudness_fn() loudness_fn()
recons.audio_data = recons.audio_data.reshape( recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length -1, obj.channels, obj.original_length
) )
self.padding = original_padding self.padding = original_padding
return recons return recons
@ -89,7 +88,7 @@ try:
except Exception as e: except Exception as e:
cfg.inference.use_dac = False cfg.inference.use_dac = False
print(str(e))
@cache @cache
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels): def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
assert cfg.sample_rate == 24_000 assert cfg.sample_rate == 24_000
@ -164,7 +163,11 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
model = model.eval() model = model.eval()
# extra metadata # extra metadata
model.sample_rate = cfg.sample_rate
# since DAC moreso models against waveforms, we can actually use a smaller sample rate
# updating it here will affect the sample rate the waveform is resampled to on encoding
if cfg.variable_sample_rate:
model.sample_rate = cfg.sample_rate
model.backend = "dac" model.backend = "dac"
return model return model
@ -205,9 +208,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
# DAC uses a different pathway # DAC uses a different pathway
if model.backend == "dac": if model.backend == "dac":
dummy = False
if metadata is None: if metadata is None:
metadata = dict( metadata = dict(
chunk_length=416, chunk_length=120,
original_length=0, original_length=0,
input_db=-12, input_db=-12,
channels=1, channels=1,
@ -215,6 +219,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
padding=False, padding=False,
dac_version='1.0.0', dac_version='1.0.0',
) )
dummy = True
# generate object with copied metadata # generate object with copied metadata
artifact = DACFile( artifact = DACFile(
codes = codes, codes = codes,
@ -227,8 +232,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding, padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version, dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
) )
artifact.dummy = dummy
return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate # to-do: inject the sample rate encoded at, because we can actually decouple
return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
kwargs = {} kwargs = {}