From 97c5241bef3cb1642a1393cb475384607e772a01 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 2 Aug 2024 22:25:49 -0500
Subject: [PATCH] fixes; throw an exception when using a NAR-only model with
 non-unified position IDs, since for some reason it outputs garbage for the
 NAR

---
 vall_e/data.py          | 6 ++++--
 vall_e/models/ar_nar.py | 2 ++
 vall_e/models/base.py   | 7 ++++++-
 vall_e/models/nar.py    | 2 +-
 4 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index fccc13d..16e23d3 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -808,7 +808,8 @@ class Dataset(_Dataset):
 			if key not in cfg.hdf5:
 				raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
 
-			metadata = cfg.hdf5[key].attrs
+			#metadata = cfg.hdf5[key].attrs
+			metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
 
 			text = cfg.hdf5[key]["text"][:]
 			resps = cfg.hdf5[key]["audio"][:, :]
@@ -915,7 +916,8 @@ class Dataset(_Dataset):
 			if key not in cfg.hdf5:
 				raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
 
-			metadata = cfg.hdf5[key].attrs
+			# coerce into a normal dict, since HDF5 attribute objects aren't pickleable in worker processes
+			metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
 
 			text = cfg.hdf5[key]["text"][:]
 			resps = cfg.hdf5[key]["audio"][:, :]
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 6570d0a..95be6c2 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -536,8 +536,10 @@ def example_usage():
 	engines = Engines({"ar+nar": engine})
 	engines.setup()
 
+	"""
 	if cfg.optimizations.model_offloading:
 		model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
+	"""
 
 	"""
 	torch.save( {
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e52decb..638bcc3 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -438,6 +438,10 @@ class Base(nn.Module):
 		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 
+		# there seems to be a problem with the NAR-only model when not using unified position IDs
+		if "len" in self.capabilities and not unified_position_ids:
+			raise Exception("ERROR: model instability for NAR-only model when not using unified position IDs.")
+
 		self.unified_position_ids = unified_position_ids
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
@@ -1081,11 +1085,12 @@ class Base(nn.Module):
 			if not isinstance(input, torch.Tensor):
 				return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
 
+			# the ending input will not have a separator appended later
 			return input.shape[0] + (0 if name in ["resp", "len"] else 1)
 
 		for batch_index, batch_input in enumerate(inputs):
 			batch = torch.cat( [
-				torch.Tensor([*range(get_input_token_length(name, input))])
+				torch.Tensor([*range(get_input_token_length(name, input))]).to(dtype=torch.int32)
 				for name, input in batch_input if name != "task"
 			] )
 
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 67d9d75..2f08439 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -390,7 +390,7 @@ def example_usage():
 	"""
 
 	model = NAR(**kwargs).to(device)
-	steps = 500
+	steps = 250
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
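
A note on the data.py change above: h5py attribute objects hold a reference to the open HDF5 file, so they cannot be pickled when DataLoader worker processes hand samples back to the main process, while a plain dict of strings can. A minimal standalone sketch of the same coercion; the file name "metadata.h5", the key "utterance-0", and the helper name are hypothetical and used only for illustration:

# standalone sketch of the attrs coercion done in data.py; names are made up for this example
import h5py

def read_metadata(hdf5_path: str, key: str) -> dict:
	with h5py.File(hdf5_path, "r") as f:
		# stringify keys and values into a plain dict so the result survives pickling
		return { f'{k}': f'{v}' for k, v in f[key].attrs.items() }

metadata = read_metadata("metadata.h5", "utterance-0")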
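
And on the position-ID change in base.py: torch.Tensor(...) defaults to float32, which cannot be used to index a positional embedding table, which is presumably why the diff adds the explicit int32 cast. A rough sketch (simplified, not the actual helper in base.py) of how non-unified position IDs restart the count for each input segment:

# rough sketch only; the segment lengths are illustrative, while the real code
# derives them per batch input and adds one for each separator token
import torch

def non_unified_position_ids(segment_lengths: list[int]) -> torch.Tensor:
	# each segment restarts its position count at zero
	return torch.cat([ torch.arange(length, dtype=torch.int32) for length in segment_lengths ])

# e.g. a 4-token text, 3-token prompt, and 5-token response:
print(non_unified_position_ids([4, 3, 5]))
# tensor([0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4], dtype=torch.int32)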