diff --git a/vall_e/config.py b/vall_e/config.py
index e1471eb..50329e1 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -220,7 +220,7 @@ class Model:
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: str | None = None # for now it sets things to be HF compatible
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
-	use_external_audio_embeddings: bool = False # subjugates the audio backend's encoding/decoding model for embeddings
+	audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 36a1f1e..2c6e8c4 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -263,12 +263,13 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
 def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
-# some experimental shit involving using the Encodec/DAC model's embeddings itself
+# an experimental way to include "trained" embeddings from the audio backend itself
+# > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
+# each audio backend does their "embeddings" a different way that isn't just a embedding weights
 @torch.inference_mode()
 def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
 	model = _load_model(device)
-
 	codes = codes.to(device=device, dtype=torch.int32)
 
 	if codes.dim() == 1:
@@ -277,7 +278,7 @@
 		codes = codes[:, quant_level]
 	codes = rearrange(codes, "t -> 1 t")
-
+	# dac conveniently has its dim = 1024
 	if cfg.audio_backend == "dac":
 		emb = model.quantizer.quantizers[quant_level]
@@ -287,6 +288,15 @@
 
 		return x
 
+	"""
+	# vocos inconveniently has its dim = 128
+	elif cfg.audio_backend == "vocos":
+		x = model.codes_to_features(codes)
+	# encodec inconveniently has its dim = 300
+	elif cfg.audio_backend == "encodec":
+		...
+	"""
+
 	raise Exception(f'Currently only DAC is supported')
 
 @torch.inference_mode()
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index ef6a808..2961442 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -107,10 +107,10 @@ class AR_NAR(Base):
 		return True
 
 	@property
-	def use_external_audio_embeddings(self) -> bool:
+	def audio_embeddings_mode(self) -> bool:
 		if hasattr(self, "config") and self.config:
-			return self.config.use_external_audio_embeddings
-		return cfg.model.use_external_audio_embeddings
+			return self.config.audio_embeddings_mode
+		return cfg.model.audio_embeddings_mode
 
 	@property
 	def version(self) -> int:
@@ -473,7 +473,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 200 if cfg.model.arch_type in ["mamba","mamba2"] else 200
+	steps = 100
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 1eb175c..b063148 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -149,7 +149,8 @@ class AudioEmbedding(nn.Module):
 		self,
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
-		sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+		sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+		external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
 	):
 		super().__init__()
 		# array of embeddings
@@ -157,10 +158,42 @@ class AudioEmbedding(nn.Module):
 		#     resp are split to where [0] is for the AR, and [1:] are reserved for NAR
 		#     + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
-		#
+		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
 
-	def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
+		self.external_mode = external_mode
+
+		# set initial weights to zero
+		if self.external_mode == "inclusive":
+			for i, embedding in enumerate(self.embeddings):
+				embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
+
+	def external_embeddings(self, input: Tensor) -> Tensor:
+		quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
+
+		# for AR, trim any stop tokens
+		has_stop_token = False
+		if quant_level == 0:
+			stop_token = self.embeddings[0].weight.shape[0] - 1
+			stop_token_indices = (input == stop_token).nonzero()
+			has_stop_token = len(stop_token_indices) > 0
+
+		if has_stop_token:
+			input = input[:stop_token_indices.min().item()]
+
+		# get external embedding
+		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
+		# resize if necessary (in case the external embeddings do not match our model dim)
+		embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
+
+		# reintroduce stop token
+		if has_stop_token:
+			stop_token = self.internal_forward( torch.Tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
+			embedding = torch.concat( [ embedding, stop_token ] )
+
+		return embedding
+
+	def internal_forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
 		quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
 
 		if self.sums and quant_level > 0:
@@ -171,26 +204,17 @@
 
 		return x
 
-# subjugates the audio backend's embeddings
-# inherits for use of the stop token
-class AudioEmbedding_External(AudioEmbedding):
-	def forward(self, input: Tensor, offset: int = 0 ) -> Tensor:
-		if not input.shape[0]:
-			return super().forward( input )
+	def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
+		x = self.internal_forward( xi, offset ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
 
-		quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
-		has_stop_token = quant_level == 0 and input[-1] == 1024
+		if self.external_mode and xi.shape[0] > 0:
+			external_embeddings = self.external_embeddings( xi )
+			if self.external_mode == "exclusive":
+				return external_embeddings
+			x += external_embeddings
 
-		if has_stop_token:
-			input = input[:-1]
+		return x
 
-		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[0].weight.dtype)
-
-		if has_stop_token:
-			stop_token = super().forward( torch.Tensor([1024]).to(device=input.device, dtype=torch.int16), 0 )
-			embedding = torch.concat( [ embedding, stop_token ] )
-
-		return embedding
 # per-level classification
 class AudioClassifier(nn.Module):
 	def __init__(
@@ -292,8 +316,8 @@ class Base(nn.Module):
 		return False
 
 	@property
-	def use_external_audio_embeddings(self) -> bool:
-		return False
+	def audio_embeddings_mode(self) -> str | None:
+		return None
 
 	@property
 	def version(self) -> int:
@@ -392,23 +416,16 @@
 				l_tokens, d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
-		elif self.use_external_audio_embeddings:
-			self.proms_emb = AudioEmbedding_External(
-				[n_prom_tokens] * self.n_prom_levels, d_model,
-				sums=audio_embedding_sums,
-			)
-			self.resps_emb = AudioEmbedding_External(
-				l_tokens, d_model,
-				sums=audio_embedding_sums,
-			)
 		else:
 			self.proms_emb = AudioEmbedding(
 				[n_prom_tokens] * self.n_prom_levels, d_model,
 				sums=audio_embedding_sums,
+				external_mode=self.audio_embeddings_mode,
 			)
 			self.resps_emb = AudioEmbedding(
 				l_tokens, d_model,
 				sums=audio_embedding_sums,
+				external_mode=self.audio_embeddings_mode,
 			)
 
 		# useless since I actually removed using these with the input processing overhaul...
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 120763e..c2a6dff 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -213,15 +213,16 @@ def replace_attention( model, klass, target, mode="math", verbose=False ):
 	return model
 
 # trim/expand a tensor (for example, in a state dict)
-def resize_weight( weight, target ):
+def resize_weight( weight, target, dim=0, random=True ):
 	# trim
-	if target < weight.shape[0]:
+	if target < weight.shape[dim]:
 		return weight[:target]
 	# expand
-	if target > weight.shape[0]:
+	if target > weight.shape[dim]:
+		fn = torch.rand if random else torch.zeros
 		return torch.stack(
 			[ x for x in weight ] +
-			[ torch.rand( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[0] ) ]
+			[ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ]
 		)
 
 	return weight
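
For reference, a minimal, self-contained sketch of what the new audio_embeddings_mode switch amounts to at the embedding layer. ToyAudioEmbedding and fake_encode_as_embedding are made-up stand-ins for the AudioEmbedding and encode_as_embedding changed above, and the sketch skips the stop-token trimming and width-resizing that the real forward()/external_embeddings() pair handles:

import torch
import torch.nn as nn

def fake_encode_as_embedding(codes: torch.Tensor, dim: int) -> torch.Tensor:
	# stand-in for vall_e.emb.qnt.encode_as_embedding(): pretend the audio codec
	# maps each code to a fixed, non-trainable embedding vector
	torch.manual_seed(0)
	table = torch.randn(1025, dim)
	return table[codes]

class ToyAudioEmbedding(nn.Module):
	def __init__(self, n_tokens: int = 1025, token_dim: int = 1024, external_mode: str | None = None):
		super().__init__()
		self.external_mode = external_mode
		self.embedding = nn.Embedding(n_tokens, token_dim)
		if external_mode == "inclusive":
			# zero-init the learned table so it starts out as a pure correction
			# on top of the codec-derived embeddings
			self.embedding.weight = nn.Parameter(torch.zeros_like(self.embedding.weight))

	def forward(self, codes: torch.Tensor) -> torch.Tensor:
		if self.external_mode == "exclusive":
			# codec embeddings only, nothing trainable
			return fake_encode_as_embedding(codes, self.embedding.weight.shape[-1])
		x = self.embedding(codes)
		if self.external_mode == "inclusive":
			# codec embeddings plus the (initially zero) learned delta
			x = x + fake_encode_as_embedding(codes, self.embedding.weight.shape[-1])
		return x

codes = torch.tensor([12, 345, 1023])
for mode in (None, "exclusive", "inclusive"):
	print(mode, ToyAudioEmbedding(external_mode=mode)(codes).shape)

With the zero initialization, "inclusive" mode starts out producing exactly the codec's own embeddings and training only has to learn a residual on top of them, while "exclusive" mode never trains an embedding table at all.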
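The codec-derived embeddings also have to be fitted to the model width, which is what the ml.resize_weight( embedding, ..., dim=-1, random=False ) call in external_embeddings() is for. A standalone sketch of that intent, assuming a [T, D] embedding tensor and zero-padding rather than random filling; pad_or_trim_features is a hypothetical helper for illustration, not the wrapper's implementation:

import torch

def pad_or_trim_features(x: torch.Tensor, width: int) -> torch.Tensor:
	# match the last (feature) dimension to `width`: trim extras, zero-pad the rest
	d = x.shape[-1]
	if width < d:
		return x[..., :width]
	if width > d:
		pad = torch.zeros(*x.shape[:-1], width - d, device=x.device, dtype=x.dtype)
		return torch.cat([x, pad], dim=-1)
	return x

emb = torch.randn(75, 1024)                    # e.g. DAC-derived embeddings, dim = 1024
print(pad_or_trim_features(emb, 1536).shape)   # torch.Size([75, 1536])
print(pad_or_trim_features(emb, 512).shape)    # torch.Size([75, 512])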