it slipped my mind that, technically, DAC can be used at any sample rate since it models waveforms; expose this as a config YAML option (variable_sample_rate) to allow that behavior
parent 00804a47e9
commit a8ffa88844
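For context, a rough sketch (not the project's actual loader) of what the new knob amounts to: the Config change further down turns sample_rate from a read-only @property into a plain field and adds variable_sample_rate, so both can be overridden from the YAML like any other field. The two key names below come from this commit; the loading mechanism and values are illustrative only.

    # minimal sketch, assuming PyYAML; the real Config carries many more fields
    from dataclasses import dataclass
    import yaml

    @dataclass
    class Cfg:
        sample_rate: int = 24_000            # default stays at 24kHz
        variable_sample_rate: bool = False   # new opt-in flag from this commit

    # hypothetical YAML snippet; only the two keys are taken from the diff
    overrides = yaml.safe_load("sample_rate: 44100\nvariable_sample_rate: true")
    cfg = Cfg(**overrides)
    print(cfg)  # Cfg(sample_rate=44100, variable_sample_rate=True)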
@@ -8,7 +8,7 @@ from pathlib import Path
 from vall_e.emb.g2p import encode as valle_phonemize
 from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
 
-input_audio = "voices_4"
+input_audio = "voices"
 input_metadata = "metadata"
 output_dataset = "training"
 
@@ -34,7 +34,11 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 			print("Does not exist:", metadata_path)
 			continue
 
-		metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
+		try:
+			metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
+		except Exception as e:
+			print("Failed to load metadata:", metadata_path, e)
+			continue
 
 		txts = []
 		wavs = []
@@ -51,42 +55,73 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 			waveform, sample_rate = None, None
 			language = metadata[filename]["language"] if "language" in metadata[filename] else "english"
 
-			for segment in metadata[filename]["segments"]:
-				id = pad(segment['id'], 4)
-				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
-
-				if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
-					continue
-
-				if waveform is None:
-					waveform, sample_rate = torchaudio.load(inpath)
-
-				start = int(segment['start'] * sample_rate)
-				end = int(segment['end'] * sample_rate)
-
-				if start < 0:
-					start = 0
-				if end >= waveform.shape[-1]:
-					end = waveform.shape[-1] - 1
-
-				if not _replace_file_extension(outpath, ".json").exists():
-					txts.append((
-						outpath,
-						segment["text"],
-						language,
-					))
-
-				if not _replace_file_extension(outpath, ".dac").exists():
-					wavs.append((
-						outpath,
-						waveform[:, start:end],
-						sample_rate
-					))
+			if len(metadata[filename]["segments"]) == 0:
+				id = pad(0, 4)
+				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
+				text = metadata[filename]["text"]
+
+				if len(text) == 0:
+					continue
+
+				if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
+					continue
+
+				if not _replace_file_extension(outpath, ".json").exists():
+					txts.append((
+						outpath,
+						text,
+						language,
+					))
+
+				if not _replace_file_extension(outpath, ".dac").exists():
+					if waveform is None:
+						waveform, sample_rate = torchaudio.load(inpath)
+
+					wavs.append((
+						outpath,
+						waveform,
+						sample_rate
+					))
+			else:
+				for segment in metadata[filename]["segments"]:
+					id = pad(segment['id'], 4)
+					outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
+
+					if _replace_file_extension(outpath, ".json").exists() and _replace_file_extension(outpath, ".dac").exists():
+						continue
+
+					if not _replace_file_extension(outpath, ".json").exists():
+						txts.append((
+							outpath,
+							segment["text"],
+							language,
+						))
+
+					if not _replace_file_extension(outpath, ".dac").exists():
+						if waveform is None:
+							waveform, sample_rate = torchaudio.load(inpath)
+
+						start = int(segment['start'] * sample_rate)
+						end = int(segment['end'] * sample_rate)
+
+						if start < 0:
+							start = 0
+						if end >= waveform.shape[-1]:
+							end = waveform.shape[-1] - 1
+
+						if end - start < 0:
+							continue
+
+						wavs.append((
+							outpath,
+							waveform[:, start:end],
+							sample_rate
+						))
 		for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"):
 			outpath, text, language = job
 			phones = valle_phonemize(text)
 			data = {
-				"text": text,
+				"text": text.strip(),
 				"phonemes": phones,
 				"language": language,
 			}
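To make the restructured branch above easier to follow, here is the rough shape of a metadata entry it expects, reconstructed purely from the keys the loop reads ("language" is optional and defaults to english; the top-level "text" is only consulted when the segment list is empty). The file-name key and the values are made up:

    # hypothetical metadata entry, inferred from the field accesses in the loop above
    metadata = {
        "audio_0001.mp3": {
            "language": "english",
            "text": "full transcription, used when segments is empty",
            "segments": [
                {"id": 0, "start": 0.00, "end": 2.35, "text": "first utterance"},
                {"id": 1, "start": 2.35, "end": 5.10, "text": "second utterance"},
            ],
        },
    }

With an empty segments list the new branch quantizes the whole waveform against the file-level text; otherwise each segment is sliced out by its start/end times (in seconds, scaled by the loaded sample rate) as before.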
@@ -98,5 +133,5 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 				qnt = valle_quantize(waveform, sr=sample_rate, device=device)
 				qnt.save(_replace_file_extension(outpath, ".dac"))
 			except Exception as e:
-				print(f"Failed to quantize: {speaker_id}")
+				print(f"Failed to quantize: {outpath}:", e)
 				continue
@@ -542,9 +542,8 @@ class Config(_Config):
 
 	fp8: FP8 = field(default_factory=lambda: FP8)
 
-	@property
-	def sample_rate(self):
-		return 24_000
+	sample_rate: int = 24_000
+	variable_sample_rate: bool = False
 
 	@property
 	def distributed(self):
@@ -7,7 +7,7 @@ import torchaudio
 
 from functools import cache
 from pathlib import Path
-
+from typing import Union
 
 from einops import rearrange
 from torch import Tensor
@@ -65,23 +65,22 @@ try:
 		recons = AudioSignal(recons, self.sample_rate)
 
-		# to-do, original implementation
 		"""
 		"""
-		resample_fn = recons.resample
-		loudness_fn = recons.loudness
+		if not hasattr(obj, "dummy") or not obj.dummy:
+			resample_fn = recons.resample
+			loudness_fn = recons.loudness
 
-		# If audio is > 10 minutes long, use the ffmpeg versions
-		if recons.signal_duration >= 10 * 60 * 60:
-			resample_fn = recons.ffmpeg_resample
-			loudness_fn = recons.ffmpeg_loudness
+			# If audio is > 10 minutes long, use the ffmpeg versions
+			if recons.signal_duration >= 10 * 60 * 60:
+				resample_fn = recons.ffmpeg_resample
+				loudness_fn = recons.ffmpeg_loudness
 
-		recons.normalize(obj.input_db)
-		resample_fn(obj.sample_rate)
-		recons = recons[..., : obj.original_length]
-		loudness_fn()
-		recons.audio_data = recons.audio_data.reshape(
-			-1, obj.channels, obj.original_length
-		)
+			recons.normalize(obj.input_db)
+			resample_fn(obj.sample_rate)
+			recons = recons[..., : obj.original_length]
+			loudness_fn()
+			recons.audio_data = recons.audio_data.reshape(
+				-1, obj.channels, obj.original_length
+			)
 		self.padding = original_padding
 		return recons
-
@@ -89,7 +88,7 @@ try:
 
 except Exception as e:
 	cfg.inference.use_dac = False
-
 	print(str(e))
 @cache
 def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
+	assert cfg.sample_rate == 24_000
@@ -164,7 +163,11 @@ def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
 	model = model.eval()
 
 	# extra metadata
-	model.sample_rate = cfg.sample_rate
+
+	# since DAC moreso models against waveforms, we can actually use a smaller sample rate
+	# updating it here will affect the sample rate the waveform is resampled to on encoding
+	if cfg.variable_sample_rate:
+		model.sample_rate = cfg.sample_rate
 	model.backend = "dac"
 
 	return model
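As a rough illustration of what the override buys (a sketch against the descript-audio-codec API rather than this repo's own encode/decode wrappers; the input file name is a placeholder): DAC's compress() resamples its input to model.sample_rate before encoding, so pointing that attribute at cfg.sample_rate changes the rate the codec effectively runs at.

    # minimal sketch, assuming the dac (descript-audio-codec) and audiotools packages
    import dac
    from audiotools import AudioSignal

    model = dac.DAC.load(dac.utils.download(model_type="44khz")).eval()

    signal = AudioSignal("example.wav")   # placeholder input file

    # with variable_sample_rate enabled, _load_dac_model() above points this at cfg.sample_rate,
    # so the waveform is resampled to that rate instead of the model's native 44.1kHz
    model.sample_rate = 24_000

    artifact = model.compress(signal)     # DACFile holding codes plus metadata
    recons = model.decompress(artifact)   # AudioSignal reconstructed by the codec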
@@ -205,9 +208,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 
 	# DAC uses a different pathway
 	if model.backend == "dac":
+		dummy = False
 		if metadata is None:
 			metadata = dict(
-				chunk_length=416,
+				chunk_length=120,
 				original_length=0,
 				input_db=-12,
 				channels=1,
@@ -215,6 +219,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 				padding=False,
 				dac_version='1.0.0',
 			)
+			dummy = True
 		# generate object with copied metadata
 		artifact = DACFile(
 			codes = codes,
@@ -227,8 +232,10 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 			padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
 			dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
 		)
+		artifact.dummy = dummy
 
-		return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate
+		# to-do: inject the sample rate encoded at, because we can actually decouple
+		return CodecMixin_decompress(model, artifact, verbose=False).audio_data[0], artifact.sample_rate
 
 
 	kwargs = {}