fixes fixes fixes (a quarter of my recently processed audio returned zero'd tensors......)

This commit is contained in:
mrq 2025-02-22 09:07:33 -06:00
parent 50506e5ebc
commit ab0abd2b12
6 changed files with 48 additions and 15 deletions

BIN
test.wav Normal file

Binary file not shown.

View File

@ -168,6 +168,7 @@ class Dataset:
use_metadata: bool = False # use genretaed metadata to aid in dataset loading
validate: bool = True # validate each utterance on wheter it can be included based on duration range caps
strict_validate: bool = False # so far only governs if a path actually exists within the dataset, as this can be a bit slow (and shouldn't really happen normally)
workers: int = 8 # number of dataloader workers to spawn
cache: bool = True # use diskcache to cache the dataset
@ -269,6 +270,11 @@ class ModelExperimentalSettings:
classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
resp_parallel_training: bool = True # used for version >= 7, computes loss for ALL quant levels rather than the randomly selected one
# this should allow for "faster" training as each sample is trained entirely, but slower backwards (and possibly less stable training, maybe)
monolithic_audio_encoder: bool = False # combines the prom/resp embeddings into one unit
# this usually sounds bad, as the model can "extract" features from the prom separate from the ones in the resp
# these technically should be as hyperparameters
# performs token dropout to compensate for errors
token_dropout_error: float = 0.0 # probability to nudge a token by ±1

View File

@ -765,10 +765,12 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
# double check if in HDF5
# this might be slow
"""
if cfg.dataset.use_hdf5 and k not in cfg.hdf5:
return False
"""
if cfg.dataset.strict_validate:
if cfg.dataset.use_hdf5:
if k not in cfg.hdf5:
return False
elif not (data_dir / id).with_suffix(_get_artifact_extension()).exists():
return False
# add to duration bucket
if type not in _durations_map:
@ -882,7 +884,6 @@ class Dataset(_Dataset):
self.duration_map = _get_duration_map( self.dataset_type )
# cull speakers if they do not have enough utterances (or cull speakers with too many utternaces)
"""
if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
keys = list(self.paths_by_spkr_name.keys())
for key in keys:
@ -893,7 +894,7 @@ class Dataset(_Dataset):
# slice away extraneous utterances
if cfg.dataset.max_utterances:
self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
"""
# flatten paths
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
@ -1272,7 +1273,11 @@ class Dataset(_Dataset):
continue
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
qnt = _load_artifact(path, return_metadata=False)
try:
qnt = _load_artifact(path, return_metadata=False)
except Exception as e:
_logger.warning(f'Failed to load artifact: {path} ({e})')
path = None
if 0 < trim_length and trim_length < qnt.shape[0]:
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )

View File

@ -21,6 +21,7 @@ from ..config import cfg
# need to validate if this is safe to import before modifying the config
from .g2p import encode as phonemize
from .qnt import encode as quantize, encode_batch as quantize_batch
from ..data import _load_artifact
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
@ -44,6 +45,10 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
# encodec requires this to be on CPU for resampling
qnt = quantize(waveform, sr=sample_rate, device=device, dtype=dtype)
if torch.count_nonzero(qnt) == 0:
tqdm.write(f"Quantization returned zero'd tensor: {outpath}")
return
if cfg.audio_backend == "dac":
state_dict = {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
@ -106,6 +111,10 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
continue
for (outpath, waveform, sample_rate, text, language), qnt in zip( batch, codes ):
if torch.count_nonzero(qnt) == 0:
tqdm.write(f"Quantization returned zero'd tensor: {outpath}")
continue
if cfg.audio_backend == "dac":
state_dict = {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
@ -165,6 +174,7 @@ def process(
output_dataset="training",
transcription_filename="whisper.json",
raise_exceptions=False,
verify_audio=False,
stride=0,
stride_offset=0,
slice="auto",
@ -353,7 +363,13 @@ def process(
text = segment["text"]
if len(text) == 0 or outpath.exists():
continue
if not verify_audio:
continue
artifact = _load_artifact( outpath )
if torch.count_nonzero(artifact) > 0:
continue
tqdm.write(f"Found zero'd quantized audio tensor: {outpath}")
start = (segment['start']-0.05)
end = (segment['end']+0.5)
@ -398,6 +414,7 @@ def main():
parser.add_argument("--output-dataset", type=str, default="training/dataset")
parser.add_argument("--transcription-filename", type=str, default="whisper.json")
parser.add_argument("--raise-exceptions", action="store_true")
parser.add_argument("--verify-audio", action="store_true")
#parser.add_argument("--low-memory", action="store_true")
parser.add_argument("--skip-existing-folders", action="store_true")
parser.add_argument("--strict-languages", action="store_true")
@ -435,6 +452,7 @@ def main():
output_dataset=args.output_dataset,
transcription_filename=args.transcription_filename,
raise_exceptions=args.raise_exceptions,
verify_audio=args.verify_audio,
stride=args.stride,
stride_offset=args.stride_offset,
slice=args.slice,
@ -453,4 +471,4 @@ def main():
)
if __name__ == "__main__":
main()
main()

View File

@ -403,9 +403,11 @@ def load_engines(training=True, **model_kwargs):
kwargs["group"] = "DDP"
kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
engine.wandb = wandb.init(project=key_name, **kwargs)
engine.wandb.watch(engine.module)
try:
engine.wandb = wandb.init(project=key_name, **kwargs)
engine.wandb.watch(engine.module)
except Exception as e:
engine.wandb = None
else:
engine.wandb = None

View File

@ -579,6 +579,10 @@ class Base(nn.Module):
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
resp_parallel_training = self.config.experimental.resp_parallel_training if self.config is not None else True
monolithic_audio_encoder = self.config.experimental.monolithic_audio_encoder if self.config is not None else False
self.resp_parallel_training = resp_parallel_training
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
@ -708,10 +712,8 @@ class Base(nn.Module):
if self.version >= 6:
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
self.monolithic_audio_encoder = False # monolithic sounds bad
if self.version >= 7:
if self.monolithic_audio_encoder:
if monolithic_audio_encoder:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 1, # masked token
n_levels=self.n_resp_levels,