fixes fixes fixes (a quarter of my recently processed audio returned zero'd tensors......)
commit ab0abd2b12
parent 50506e5ebc
@@ -168,6 +168,7 @@ class Dataset:
 	use_metadata: bool = False # use generated metadata to aid in dataset loading
 	validate: bool = True # validate each utterance on whether it can be included based on duration range caps
+	strict_validate: bool = False # so far only governs if a path actually exists within the dataset, as this can be a bit slow (and shouldn't really happen normally)
 	workers: int = 8 # number of dataloader workers to spawn
 	cache: bool = True # use diskcache to cache the dataset
@@ -269,6 +270,11 @@ class ModelExperimentalSettings:
 	classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
 	max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
 
+	resp_parallel_training: bool = True # used for version >= 7, computes loss for ALL quant levels rather than the randomly selected one
+	# this should allow for "faster" training as each sample is trained entirely, but slower backwards (and possibly less stable training, maybe)
+	monolithic_audio_encoder: bool = False # combines the prom/resp embeddings into one unit
+	# this usually sounds bad, as the model can "extract" features from the prom separate from the ones in the resp
+
 	# these technically should be as hyperparameters
 	# performs token dropout to compensate for errors
 	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
@@ -765,10 +765,12 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 		# double check if in HDF5
 		# this might be slow
-		"""
-		if cfg.dataset.use_hdf5 and k not in cfg.hdf5:
-			return False
-		"""
+		if cfg.dataset.strict_validate:
+			if cfg.dataset.use_hdf5:
+				if k not in cfg.hdf5:
+					return False
+			elif not (data_dir / id).with_suffix(_get_artifact_extension()).exists():
+				return False
 
 		# add to duration bucket
 		if type not in _durations_map:
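In effect, the strict path replaces the old commented-out HDF5-only check with a two-way existence test: a cheap key lookup when the dataset is HDF5-backed, a filesystem stat otherwise. A minimal standalone sketch of the same decision, under assumed names (exists_in_dataset, the hdf5 handle parameter, and the .enc extension are illustrative, not the project's API):

from pathlib import Path

def exists_in_dataset(key: str, relpath: str, data_dir: Path, hdf5=None, extension: str = ".enc") -> bool:
	# with an HDF5-backed dataset the key lookup is cheap; the filesystem stat is the slow path
	if hdf5 is not None:
		return key in hdf5
	return (data_dir / relpath).with_suffix(extension).exists()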
@@ -882,7 +884,6 @@ class Dataset(_Dataset):
 		self.duration_map = _get_duration_map( self.dataset_type )
 
 		# cull speakers if they do not have enough utterances (or cull speakers with too many utterances)
-		"""
 		if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
 			keys = list(self.paths_by_spkr_name.keys())
 			for key in keys:
@@ -893,7 +894,7 @@ class Dataset(_Dataset):
 				# slice away extraneous utterances
 				if cfg.dataset.max_utterances:
 					self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
-		"""
 
 		# flatten paths
 		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
@@ -1272,7 +1273,11 @@ class Dataset(_Dataset):
 				continue
 			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
-			qnt = _load_artifact(path, return_metadata=False)
+			try:
+				qnt = _load_artifact(path, return_metadata=False)
+			except Exception as e:
+				_logger.warning(f'Failed to load artifact: {path} ({e})')
+				path = None
 
 		if 0 < trim_length and trim_length < qnt.shape[0]:
 			qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
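This is the usual defensive-load pattern: a corrupt or truncated artifact should drop the sample (path is cleared so later code can fall back) rather than kill the dataloader worker. A generic sketch of the same idea, assuming plain torch-serialized files instead of the project's _load_artifact:

import logging
import torch

_logger = logging.getLogger(__name__)

def safe_load(path):
	# return the deserialized artifact, or None if the file is unreadable or corrupt
	try:
		return torch.load(path, map_location="cpu")
	except Exception as e:
		_logger.warning(f"Failed to load artifact: {path} ({e})")
		return None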
@@ -21,6 +21,7 @@ from ..config import cfg
 # need to validate if this is safe to import before modifying the config
 from .g2p import encode as phonemize
 from .qnt import encode as quantize, encode_batch as quantize_batch
+from ..data import _load_artifact
 
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
@@ -44,6 +45,10 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
 	# encodec requires this to be on CPU for resampling
 	qnt = quantize(waveform, sr=sample_rate, device=device, dtype=dtype)
 
+	if torch.count_nonzero(qnt) == 0:
+		tqdm.write(f"Quantization returned zero'd tensor: {outpath}")
+		return
+
 	if cfg.audio_backend == "dac":
 		state_dict = {
 			"codes": qnt.codes.cpu().numpy().astype(np.uint16),
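The new guard is just torch.count_nonzero on the quantizer output: all-zero codes mean the encode silently failed, so the job is skipped instead of writing a useless artifact. A minimal standalone sketch, with a hypothetical (levels, frames) code tensor standing in for a real quantizer output:

import torch

def quantization_failed(codes: torch.Tensor) -> bool:
	# an all-zero code tensor almost certainly means the encoder produced nothing usable
	return int(torch.count_nonzero(codes)) == 0

qnt = torch.zeros(8, 75, dtype=torch.int16)  # hypothetical shape, not the project's artifact format
assert quantization_failed(qnt)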
@@ -106,6 +111,10 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
 			continue
 
 		for (outpath, waveform, sample_rate, text, language), qnt in zip( batch, codes ):
+			if torch.count_nonzero(qnt) == 0:
+				tqdm.write(f"Quantization returned zero'd tensor: {outpath}")
+				continue
+
 			if cfg.audio_backend == "dac":
 				state_dict = {
 					"codes": qnt.codes.cpu().numpy().astype(np.uint16),
@@ -165,6 +174,7 @@ def process(
 	output_dataset="training",
 	transcription_filename="whisper.json",
 	raise_exceptions=False,
+	verify_audio=False,
 	stride=0,
 	stride_offset=0,
 	slice="auto",
@@ -353,7 +363,13 @@ def process(
 				text = segment["text"]
 
 				if len(text) == 0 or outpath.exists():
-					continue
+					if not verify_audio:
+						continue
+
+					artifact = _load_artifact( outpath )
+					if torch.count_nonzero(artifact) > 0:
+						continue
+					tqdm.write(f"Found zero'd quantized audio tensor: {outpath}")
 
 				start = (segment['start']-0.05)
 				end = (segment['end']+0.5)
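With --verify-audio, already-existing outputs are re-opened and re-queued for processing when their code tensor is entirely zero. For a one-off audit outside the processing script, a rough sweep along the same lines might look like this; it is only a sketch and assumes plain torch-serialized artifacts rather than the project's _load_artifact format and extension:

from pathlib import Path
import torch

def find_zeroed_artifacts(dataset_dir: str, suffix: str = ".pt") -> list:
	# walk the processed dataset and collect files whose code tensors are entirely zero
	bad = []
	for path in Path(dataset_dir).rglob(f"*{suffix}"):
		codes = torch.load(path, map_location="cpu")
		if torch.is_tensor(codes) and int(torch.count_nonzero(codes)) == 0:
			bad.append(path)
	return bad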
@@ -398,6 +414,7 @@ def main():
 	parser.add_argument("--output-dataset", type=str, default="training/dataset")
 	parser.add_argument("--transcription-filename", type=str, default="whisper.json")
 	parser.add_argument("--raise-exceptions", action="store_true")
+	parser.add_argument("--verify-audio", action="store_true")
 	#parser.add_argument("--low-memory", action="store_true")
 	parser.add_argument("--skip-existing-folders", action="store_true")
 	parser.add_argument("--strict-languages", action="store_true")
@@ -435,6 +452,7 @@ def main():
 		output_dataset=args.output_dataset,
 		transcription_filename=args.transcription_filename,
 		raise_exceptions=args.raise_exceptions,
+		verify_audio=args.verify_audio,
 		stride=args.stride,
 		stride_offset=args.stride_offset,
 		slice=args.slice,
@@ -453,4 +471,4 @@ def main():
 	)
 
 if __name__ == "__main__":
 	main()
@@ -403,9 +403,11 @@ def load_engines(training=True, **model_kwargs):
 				kwargs["group"] = "DDP"
 				kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
 
-			engine.wandb = wandb.init(project=key_name, **kwargs)
-			engine.wandb.watch(engine.module)
+			try:
+				engine.wandb = wandb.init(project=key_name, **kwargs)
+				engine.wandb.watch(engine.module)
+			except Exception as e:
+				engine.wandb = None
 		else:
 			engine.wandb = None
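Wrapping wandb.init keeps run tracking best-effort, so a logging failure (offline machine, bad API key) no longer aborts engine loading. The same pattern in isolation, assuming only that wandb is installed; init_logger and the print fallback are illustrative, not part of the codebase:

import wandb

def init_logger(project: str, **kwargs):
	# best-effort: return a run handle, or None if wandb cannot initialize
	try:
		run = wandb.init(project=project, **kwargs)
	except Exception as e:
		print(f"wandb unavailable, continuing without logging: {e}")
		run = None
	return run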
@@ -579,6 +579,10 @@ class Base(nn.Module):
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
 
+		resp_parallel_training = self.config.experimental.resp_parallel_training if self.config is not None else True
+		monolithic_audio_encoder = self.config.experimental.monolithic_audio_encoder if self.config is not None else False
+
+		self.resp_parallel_training = resp_parallel_training
 
 		n_tasks = self.config.tasks if self.config is not None else 8
 		n_langs = self.config.langs if self.config is not None else 2
@@ -708,10 +712,8 @@ class Base(nn.Module):
 		if self.version >= 6:
 			self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
 
-		self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
-		self.monolithic_audio_encoder = False # monolithic sounds bad
 
 		if self.version >= 7:
-			if self.monolithic_audio_encoder:
+			if monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
 					n_tokens=n_audio_tokens + 1, # masked token
 					n_levels=self.n_resp_levels,