diff --git a/vall_e/config.py b/vall_e/config.py
index 9b92fbe..becd445 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -363,6 +363,10 @@ class Model:
 	def audio_tokens(self):
 		if isinstance(self.size, dict) and hasattr(self.size, "audio_tokens"):
 			return self.size['audio_tokens']
+
+		if cfg.audio_backend == "nemo":
+			return 1000
+
 		return 1024
 
 	@property
diff --git a/vall_e/data.py b/vall_e/data.py
index 550d4e2..fcfb439 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1250,21 +1250,23 @@ class Dataset(_Dataset):
 			trim_length = 0
 
 		for _ in range(cfg.dataset.prompt_max_samples):
+			# yuck
+			path = None
 			if reference is not None:
-				# yuck
-				path = None
 				if random.random() < cfg.dataset.prompt_similar_p:
 					try:
 						path = self.get_similar_utterance( reference, offset = len(prom_list) )
 					except Exception as e:
 						path = None
-				if not path:
-					path = random.choice(choices)
-			else:
+
+			if not path:
 				path = random.choice(choices)
 
 			if cfg.dataset.use_hdf5:
 				key = _get_hdf5_path(path)
+				if key not in cfg.hdf5:
+					_logger.warning(f'Key of Path ({path}) not in HDF5: {key}')
+					continue
 				qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 			else:
 				qnt = _load_artifact(path, return_metadata=False)
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 9948ffa..4ab1315 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -241,20 +241,51 @@ def load_engines(training=True, **model_kwargs):
 				("classifier.bias", model.n_vocab ),
 			]
 
+			last_embedding_keys = {}
+
 			# correcting an oversight
+			"""
 			if model.config.experimental.split_classifiers and "len" in model.capabilities:
 				len_idx, nar_0_idx = model.classifiers.indices(["len", "NAR:0:0"])
 
 				keys.append((f"classifiers.proj.{len_idx}.weight", 11))
 				keys.append((f"classifiers.proj.{len_idx}.bias", 11))
 
-				keys.append((f"classifiers.proj.{nar_0_idx}.weight", 1024))
-				keys.append((f"classifiers.proj.{nar_0_idx}.bias", 1024))
+				keys.append((f"classifiers.proj.{nar_0_idx}.weight", model.config.audio_tokens))
+				keys.append((f"classifiers.proj.{nar_0_idx}.bias", model.config.audio_tokens))
+			"""
+
+			# correcting an oversight
+			"""
+			if True:
+				keys.append((f"classifiers.proj.0.weight", model.config.audio_tokens+1))
+				for i in range(1,9):
+					keys.append((f"classifiers.proj.{i}.weight", model.config.audio_tokens))
+
+				keys.append((f"resps_emb.embeddings.0.weight", model.config.audio_tokens+1))
+				keys.append((f"resps_emb.embeddings.8.weight", model.config.audio_tokens+1))
+
+				for i in range(1,8):
+					keys.append((f"resps_emb.embeddings.{i}.weight", model.config.audio_tokens))
+
+				for i in range(8):
+					keys.append((f"proms_emb.embeddings.{i}.weight", model.config.audio_tokens))
+
+				last_embedding_keys = {
+					"classifiers.proj.0.weight": state["classifiers.proj.0.weight"][-1].clone().detach(),
+					"resps_emb.embeddings.0.weight": state["resps_emb.embeddings.0.weight"][-1].clone().detach(),
+					"resps_emb.embeddings.8.weight": state["resps_emb.embeddings.8.weight"][-1].clone().detach(),
+				}
+			"""
+
 			for k, tokens in keys:
 				if k not in state:
 					continue
 				state[k] = ml.resize_weight( state[k], tokens )
 
+			for k, v in last_embedding_keys.items():
+				state[k][-1] = v
+
 			# stuff to inject new layers into an existing model train over (not recommended, it doesnt amount to anything)
 			"""
 			if True:
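Note on the resizing this patch leans on: ml.resize_weight's implementation is not part of this diff, so the snippet below is only a minimal sketch, assuming it truncates or zero-pads a checkpoint tensor along its first dimension. The names resize_weight_sketch and stop_row are hypothetical and exist only to illustrate why the patch stashes and restores the last row via last_embedding_keys when audio_tokens changes (e.g. 1024 for the default backend vs 1000 for "nemo").

# Minimal sketch, not vall_e's ml.resize_weight: assumes "resizing" means
# truncating or zero-padding a 2-D weight tensor along dim 0.
import torch

def resize_weight_sketch(weight: torch.Tensor, tokens: int) -> torch.Tensor:
	# Keep the first `tokens` rows, or pad with zero rows if growing.
	rows, dim = weight.shape
	if tokens <= rows:
		return weight[:tokens].clone()
	resized = torch.zeros(tokens, dim, dtype=weight.dtype)
	resized[:rows] = weight
	return resized

# Hypothetical usage: shrink a (1024 + 1)-entry embedding to (1000 + 1) entries
# while preserving its final row, mirroring the last_embedding_keys bookkeeping
# in the engines/__init__.py hunk above.
old = torch.randn(1025, 512)
stop_row = old[-1].clone().detach()
new = resize_weight_sketch(old, 1001)
new[-1] = stop_row
assert new.shape == (1001, 512)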