diff --git a/test.wav b/test.wav
new file mode 100644
index 0000000..e6b0b7c
Binary files /dev/null and b/test.wav differ
diff --git a/vall_e/config.py b/vall_e/config.py
index 2359244..7816bd4 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -168,6 +168,7 @@ class Dataset:
 	use_metadata: bool = False # use generated metadata to aid in dataset loading
 	validate: bool = True # validate each utterance on whether it can be included based on duration range caps
+	strict_validate: bool = False # so far this only governs whether a path actually exists within the dataset, as checking can be a bit slow (and missing paths shouldn't normally happen)
 	workers: int = 8 # number of dataloader workers to spawn
 	cache: bool = True # use diskcache to cache the dataset
@@ -269,6 +270,11 @@ class ModelExperimentalSettings:
 	classifiers_bias: bool = True # base LLaMAs do not bias the output heads, but my existing weights do
 	max_position_embeddings: int = 70 * 65 * 5 # 5 minutes of audio
+	resp_parallel_training: bool = True # used for version >= 7, computes loss for ALL quant levels rather than the randomly selected one
+	# this should allow for "faster" training as each sample is trained entirely, but with a slower backwards pass (and possibly less stable training)
+	monolithic_audio_encoder: bool = False # combines the prom/resp embeddings into one unit
+	# this usually sounds bad, as the model can "extract" features from the prom separately from the ones in the resp
+	# these technically should be hyperparameters
 	# performs token dropout to compensate for errors
 	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
diff --git a/vall_e/data.py b/vall_e/data.py
index 45fb3f9..81424fd 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -765,10 +765,12 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 		# double check if in HDF5
 		# this might be slow
-		"""
-		if cfg.dataset.use_hdf5 and k not in cfg.hdf5:
-			return False
-		"""
+		if cfg.dataset.strict_validate:
+			if cfg.dataset.use_hdf5:
+				if k not in cfg.hdf5:
+					return False
+			elif not (data_dir / id).with_suffix(_get_artifact_extension()).exists():
+				return False

 		# add to duration bucket
 		if type not in _durations_map:
@@ -882,7 +884,6 @@ class Dataset(_Dataset):
 		self.duration_map = _get_duration_map( self.dataset_type )

 		# cull speakers if they do not have enough utterances (or cull speakers with too many utterances)
-		"""
 		if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
 			keys = list(self.paths_by_spkr_name.keys())
 			for key in keys:
@@ -893,7 +894,7 @@

 				# slice away extraneous utterances
 				if cfg.dataset.max_utterances:
 					self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
-		"""
+
 		# flatten paths
 		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
@@ -1272,7 +1273,11 @@ class Dataset(_Dataset):
 				continue
 			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
-			qnt = _load_artifact(path, return_metadata=False)
+			try:
+				qnt = _load_artifact(path, return_metadata=False)
+			except Exception as e:
+				_logger.warning(f'Failed to load artifact: {path} ({e})')
+				path = None

 		if 0 < trim_length and trim_length < qnt.shape[0]:
 			qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
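The strict_validate branch above trades dataset-loading speed for certainty that every metadata entry is actually backed by data. A minimal standalone sketch of the same check, where dataset_keys, data_dir, and artifact_ext are hypothetical stand-ins for cfg.hdf5, the dataset root, and _get_artifact_extension():

from pathlib import Path

def utterance_exists( key, relpath, use_hdf5, dataset_keys, data_dir, artifact_ext=".enc" ):
	# with an HDF5-backed dataset, existence is a cheap key lookup
	if use_hdf5:
		return key in dataset_keys
	# otherwise stat the quantized artifact on disk; this is the slow part the flag guards
	return ( Path( data_dir ) / relpath ).with_suffix( artifact_ext ).exists()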
diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py
index 19fb9df..427c8e4 100644
--- a/vall_e/emb/process.py
+++ b/vall_e/emb/process.py
@@ -21,6 +21,7 @@ from ..config import cfg
 # need to validate if this is safe to import before modifying the config
 from .g2p import encode as phonemize
 from .qnt import encode as quantize, encode_batch as quantize_batch
+from ..data import _load_artifact

 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
@@ -44,6 +45,10 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
 	# encodec requires this to be on CPU for resampling
 	qnt = quantize(waveform, sr=sample_rate, device=device, dtype=dtype)

+	if torch.count_nonzero(qnt) == 0:
+		tqdm.write(f"Quantization returned zero'd tensor: {outpath}")
+		return
+
 	if cfg.audio_backend == "dac":
 		state_dict = {
 			"codes": qnt.codes.cpu().numpy().astype(np.uint16),
@@ -106,6 +111,10 @@ def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=Tru
 			continue

 		for (outpath, waveform, sample_rate, text, language), qnt in zip( batch, codes ):
+			if torch.count_nonzero(qnt) == 0:
+				tqdm.write(f"Quantization returned zero'd tensor: {outpath}")
+				continue
+
 			if cfg.audio_backend == "dac":
 				state_dict = {
 					"codes": qnt.codes.cpu().numpy().astype(np.uint16),
@@ -165,6 +174,7 @@ def process(
 	output_dataset="training",
 	transcription_filename="whisper.json",
 	raise_exceptions=False,
+	verify_audio=False,
 	stride=0,
 	stride_offset=0,
 	slice="auto",
@@ -353,7 +363,13 @@
 			text = segment["text"]

 			if len(text) == 0 or outpath.exists():
-				continue
+				if not verify_audio:
+					continue
+
+				artifact = _load_artifact( outpath )
+				if torch.count_nonzero(artifact) > 0:
+					continue
+				tqdm.write(f"Found zero'd quantized audio tensor: {outpath}")

 			start = (segment['start']-0.05)
 			end = (segment['end']+0.5)
@@ -398,6 +414,7 @@ def main():
 	parser.add_argument("--output-dataset", type=str, default="training/dataset")
 	parser.add_argument("--transcription-filename", type=str, default="whisper.json")
 	parser.add_argument("--raise-exceptions", action="store_true")
+	parser.add_argument("--verify-audio", action="store_true")
 	#parser.add_argument("--low-memory", action="store_true")
 	parser.add_argument("--skip-existing-folders", action="store_true")
 	parser.add_argument("--strict-languages", action="store_true")
@@ -435,6 +452,7 @@
 		output_dataset=args.output_dataset,
 		transcription_filename=args.transcription_filename,
 		raise_exceptions=args.raise_exceptions,
+		verify_audio=args.verify_audio,
 		stride=args.stride,
 		stride_offset=args.stride_offset,
 		slice=args.slice,
@@ -453,4 +471,4 @@
 	)

 if __name__ == "__main__":
-	main()
\ No newline at end of file
+	main()
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 8d9c133..133b7f1 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -403,9 +403,11 @@ def load_engines(training=True, **model_kwargs):
 				kwargs["group"] = "DDP"
 				kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
-
-				engine.wandb = wandb.init(project=key_name, **kwargs)
-				engine.wandb.watch(engine.module)
+				try:
+					engine.wandb = wandb.init(project=key_name, **kwargs)
+					engine.wandb.watch(engine.module)
+				except Exception as e:
+					engine.wandb = None
 			else:
 				engine.wandb = None
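The torch.count_nonzero guards added to process.py all enforce the same invariant: a quantized clip that comes back as all zeroes is treated as a silently failed encode, skipped when first produced and re-checked by --verify-audio on later runs. A hedged sketch of that invariant in isolation, with illustrative shapes and dtype rather than the exact artifact layout:

import torch

def is_valid_quantization( codes: torch.Tensor ) -> bool:
	# an all-zero code tensor almost certainly means the encoder failed silently
	return torch.count_nonzero( codes ) > 0

# e.g. 8 residual levels by 75 frames of int16 codes
assert not is_valid_quantization( torch.zeros( 8, 75, dtype=torch.int16 ) )
assert is_valid_quantization( torch.randint( 1, 1024, ( 8, 75 ), dtype=torch.int16 ) )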
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e268752..21ccdcf 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -579,6 +579,10 @@ class Base(nn.Module):
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
+		resp_parallel_training = self.config.experimental.resp_parallel_training if self.config is not None else True
+		monolithic_audio_encoder = self.config.experimental.monolithic_audio_encoder if self.config is not None else False
+
+		self.resp_parallel_training = resp_parallel_training

 		n_tasks = self.config.tasks if self.config is not None else 8
 		n_langs = self.config.langs if self.config is not None else 2
@@ -708,10 +712,8 @@
 		if self.version >= 6:
 			self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)

-		self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
-		self.monolithic_audio_encoder = False # monolithic sounds bad
 		if self.version >= 7:
-			if self.monolithic_audio_encoder:
+			if monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
 					n_tokens=n_audio_tokens + 1, # masked token
 					n_levels=self.n_resp_levels,
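The base.py change above promotes two previously hardcoded toggles into ModelExperimentalSettings. Roughly what resp_parallel_training switches between, as an illustrative sketch rather than the repo's actual loss code:

import random
import torch
import torch.nn.functional as F

def resp_loss( logits: torch.Tensor, targets: torch.Tensor, parallel: bool ) -> torch.Tensor:
	# logits: (quant levels, sequence length, vocab); targets: (quant levels, sequence length)
	if parallel:
		# version >= 7 behaviour: every quant level contributes, so the whole sample
		# is trained each step, at the cost of a heavier backwards pass
		return sum( F.cross_entropy( logits[l], targets[l] ) for l in range( logits.shape[0] ) )
	# old behaviour: train one randomly selected level per sample
	l = random.randrange( logits.shape[0] )
	return F.cross_entropy( logits[l], targets[l] )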