fixes fixes fixes (a quarter of my recently processed audio returned zero'd tensors......)

mrq 2025-02-22 09:07:33 -06:00
parent 50506e5ebc
commit ab0abd2b12
6 changed files with 48 additions and 15 deletions

test.wav: new binary file (contents not shown)

View File

@@ -168,6 +168,7 @@ class Dataset:
     use_metadata: bool = False # use generated metadata to aid in dataset loading
     validate: bool = True # validate each utterance on whether it can be included based on duration range caps
+    strict_validate: bool = False # so far only governs if a path actually exists within the dataset, as this can be a bit slow (and shouldn't really happen normally)
     workers: int = 8 # number of dataloader workers to spawn
     cache: bool = True # use diskcache to cache the dataset
@@ -269,6 +270,11 @@ class ModelExperimentalSettings:
     classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
     max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
+    resp_parallel_training: bool = True # used for version >= 7, computes loss for ALL quant levels rather than the randomly selected one
+    # this should allow for "faster" training as each sample is trained entirely, but slower backward passes (and possibly less stable training, maybe)
+    monolithic_audio_encoder: bool = False # combines the prom/resp embeddings into one unit
+    # this usually sounds bad, as the model can "extract" features from the prom separate from the ones in the resp
     # these technically should be hyperparameters
     # performs token dropout to compensate for errors
     token_dropout_error: float = 0.0 # probability to nudge a token by ±1
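For reference, a minimal sketch of setting the two new experimental flags programmatically; the import path is an assumption based on this config module's layout, and all other fields keep their defaults:

```python
# Hypothetical usage of the new flags (import path assumed).
from vall_e.config import ModelExperimentalSettings

experimental = ModelExperimentalSettings(
    resp_parallel_training=True,    # compute loss over all quant levels (version >= 7)
    monolithic_audio_encoder=False, # keep prom/resp embeddings as separate encoders
)
```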

View File

@@ -765,10 +765,12 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
         # double check if in HDF5
         # this might be slow
-        """
-        if cfg.dataset.use_hdf5 and k not in cfg.hdf5:
-            return False
-        """
+        if cfg.dataset.strict_validate:
+            if cfg.dataset.use_hdf5:
+                if k not in cfg.hdf5:
+                    return False
+            elif not (data_dir / id).with_suffix(_get_artifact_extension()).exists():
+                return False
         # add to duration bucket
         if type not in _durations_map:
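In effect, the re-enabled (and extended) check gates each utterance on whether its data actually exists. For the on-disk case, a standalone sketch; the function name and the ".dac" extension here are illustrative, not the project's real helpers:

```python
from pathlib import Path

def utterance_on_disk(data_dir: Path, utt_id: str, ext: str) -> bool:
    # Mirrors the new strict_validate branch for on-disk datasets: the
    # utterance only passes validation if its quantized artifact file exists.
    return (data_dir / utt_id).with_suffix(ext).exists()

# e.g. utterance_on_disk(Path("training/data"), "speaker/000001", ".dac")
```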
@@ -882,7 +884,6 @@ class Dataset(_Dataset):
         self.duration_map = _get_duration_map( self.dataset_type )
         # cull speakers if they do not have enough utterances (or cull speakers with too many utterances)
-        """
         if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
             keys = list(self.paths_by_spkr_name.keys())
             for key in keys:
@@ -893,7 +894,7 @@ class Dataset(_Dataset):
                 # slice away extraneous utterances
                 if cfg.dataset.max_utterances:
                     self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
-        """
         # flatten paths
         self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
@@ -1272,7 +1273,11 @@
                     continue
                 qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
             else:
-                qnt = _load_artifact(path, return_metadata=False)
+                try:
+                    qnt = _load_artifact(path, return_metadata=False)
+                except Exception as e:
+                    _logger.warning(f'Failed to load artifact: {path} ({e})')
+                    path = None
             if 0 < trim_length and trim_length < qnt.shape[0]:
                 qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
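The same defensive pattern in isolation: a failed load is logged and the sample marked invalid, so one corrupt file cannot take down a dataloader worker. A self-contained sketch, with `load_artifact` standing in for the project's `_load_artifact`:

```python
import logging

_logger = logging.getLogger(__name__)

def try_load_artifact(path, load_artifact):
    # Best-effort load: a corrupt or truncated artifact is logged and skipped
    # (returning None) so the caller can resample a different utterance.
    try:
        return load_artifact(path, return_metadata=False)
    except Exception as e:
        _logger.warning(f"Failed to load artifact: {path} ({e})")
        return None
```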

View File

@@ -21,6 +21,7 @@ from ..config import cfg
 # need to validate if this is safe to import before modifying the config
 from .g2p import encode as phonemize
 from .qnt import encode as quantize, encode_batch as quantize_batch
+from ..data import _load_artifact
 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)
@@ -44,6 +45,10 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
     # encodec requires this to be on CPU for resampling
     qnt = quantize(waveform, sr=sample_rate, device=device, dtype=dtype)
+    if torch.count_nonzero(qnt) == 0:
+        tqdm.write(f"Quantization returned zero'd tensor: {outpath}")
+        return
     if cfg.audio_backend == "dac":
         state_dict = {
             "codes": qnt.codes.cpu().numpy().astype(np.uint16),
@@ -106,6 +111,10 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
         continue
     for (outpath, waveform, sample_rate, text, language), qnt in zip( batch, codes ):
+        if torch.count_nonzero(qnt) == 0:
+            tqdm.write(f"Quantization returned zero'd tensor: {outpath}")
+            continue
         if cfg.audio_backend == "dac":
             state_dict = {
                 "codes": qnt.codes.cpu().numpy().astype(np.uint16),
@@ -165,6 +174,7 @@ def process(
     output_dataset="training",
     transcription_filename="whisper.json",
     raise_exceptions=False,
+    verify_audio=False,
     stride=0,
     stride_offset=0,
     slice="auto",
@@ -353,7 +363,13 @@ def process(
             text = segment["text"]
             if len(text) == 0 or outpath.exists():
-                continue
+                if not verify_audio:
+                    continue
+                artifact = _load_artifact( outpath )
+                if torch.count_nonzero(artifact) > 0:
+                    continue
+                tqdm.write(f"Found zero'd quantized audio tensor: {outpath}")
             start = (segment['start']-0.05)
             end = (segment['end']+0.5)
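Put together, `--verify-audio` changes the skip condition from "output exists" to "output exists and holds non-zero codes", so previously written zeroed artifacts get re-quantized. A rough sketch of the decision, reusing this module's `torch` and `_load_artifact` and the same names as the diff:

```python
def should_skip(outpath, verify_audio: bool) -> bool:
    # Old behavior: any existing output is skipped outright.
    if not outpath.exists():
        return False
    if not verify_audio:
        return True
    # --verify-audio: re-open the artifact and only skip it if its codes
    # are non-zero; zeroed artifacts fall through and are re-processed.
    artifact = _load_artifact(outpath)
    return torch.count_nonzero(artifact) > 0
```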
@@ -398,6 +414,7 @@ def main():
     parser.add_argument("--output-dataset", type=str, default="training/dataset")
     parser.add_argument("--transcription-filename", type=str, default="whisper.json")
     parser.add_argument("--raise-exceptions", action="store_true")
+    parser.add_argument("--verify-audio", action="store_true")
     #parser.add_argument("--low-memory", action="store_true")
     parser.add_argument("--skip-existing-folders", action="store_true")
     parser.add_argument("--strict-languages", action="store_true")
@@ -435,6 +452,7 @@ def main():
         output_dataset=args.output_dataset,
         transcription_filename=args.transcription_filename,
         raise_exceptions=args.raise_exceptions,
+        verify_audio=args.verify_audio,
         stride=args.stride,
         stride_offset=args.stride_offset,
         slice=args.slice,
@@ -453,4 +471,4 @@ def main():
     )
 if __name__ == "__main__":
     main()

View File

@@ -403,9 +403,11 @@ def load_engines(training=True, **model_kwargs):
             kwargs["group"] = "DDP"
             kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
-            engine.wandb = wandb.init(project=key_name, **kwargs)
-            engine.wandb.watch(engine.module)
+            try:
+                engine.wandb = wandb.init(project=key_name, **kwargs)
+                engine.wandb.watch(engine.module)
+            except Exception as e:
+                engine.wandb = None
         else:
             engine.wandb = None
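This is the usual best-effort pattern for optional experiment tracking: if `wandb.init` fails (offline machine, missing login, unreachable service), training proceeds with logging disabled rather than crashing. A minimal self-contained sketch of the same idea:

```python
import wandb  # optional dependency

def init_tracking(project: str, module, **kwargs):
    # Best-effort init: any failure downgrades to "no tracking" (None),
    # which callers are expected to check before logging (assumed).
    try:
        run = wandb.init(project=project, **kwargs)
        run.watch(module)
        return run
    except Exception:
        return None
```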

View File

@@ -579,6 +579,10 @@ class Base(nn.Module):
         masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
         ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
+        resp_parallel_training = self.config.experimental.resp_parallel_training if self.config is not None else True
+        monolithic_audio_encoder = self.config.experimental.monolithic_audio_encoder if self.config is not None else False
+        self.resp_parallel_training = resp_parallel_training
         n_tasks = self.config.tasks if self.config is not None else 8
         n_langs = self.config.langs if self.config is not None else 2
@@ -708,10 +712,8 @@
         if self.version >= 6:
             self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
-        self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
-        self.monolithic_audio_encoder = False # monolithic sounds bad
         if self.version >= 7:
-            if self.monolithic_audio_encoder:
+            if monolithic_audio_encoder:
                 self.audio_emb = AudioEncoder(
                     n_tokens=n_audio_tokens + 1, # masked token
                     n_levels=self.n_resp_levels,
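As for what the promoted flag controls: with `resp_parallel_training` enabled, loss is computed across every RVQ level of the response each step instead of one randomly chosen level. A sketch of the two regimes; the loop shape and the `level_loss` stand-in are assumptions for illustration, not the model's actual loss code:

```python
import random
import torch
import torch.nn.functional as F

def level_loss(logits_l: torch.Tensor, targets_l: torch.Tensor) -> torch.Tensor:
    # Per-level cross-entropy over [batch, time, vocab] logits (stand-in).
    return F.cross_entropy(logits_l.transpose(1, 2), targets_l)

def resp_loss(logits, targets, n_resp_levels: int, parallel: bool) -> torch.Tensor:
    if parallel:
        # resp_parallel_training: every quant level contributes each step,
        # training the whole sample at once at the cost of a heavier backward.
        return sum(level_loss(logits[l], targets[l]) for l in range(n_resp_levels))
    # Otherwise: one randomly selected level per sample, as the old way.
    l = random.randrange(n_resp_levels)
    return level_loss(logits[l], targets[l])
```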