From 35510c954c4668817626ad01bbcfed00dbffb9f0 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 13 Feb 2025 16:47:59 -0600
Subject: [PATCH] ugh

---
 vall_e/models/base.py   | 12 ++++++------
 vall_e/utils/wrapper.py |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index acc70e8..c68f1ee 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -135,7 +135,7 @@ def _interleave_sequence_flatten( input: list[torch.Tensor] ):
 # automagically parses a batch-list and returns it as a list
 """
-class Embedding(nn.Embedding):
+class Embedding(ml.Embedding):
 	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
@@ -192,7 +192,7 @@ class AudioEmbedding_Old(nn.Module):
 		# array of embeddings
 		#	proms are [0, resp_levels]
 		#	resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
 		self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
@@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
 		# array of embeddings
 		#	proms are [0, resp_levels]
 		#	resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
 		#
@@ -350,7 +350,7 @@ class AudioEncoder(nn.Module):
 		token_dim: int,
 	):
 		super().__init__()
-		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
 		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)

 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
@@ -2201,7 +2201,7 @@ if __name__ == "__main__":
 			if is_from_pretrained:
 				n_vocab = end - start

-				embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+				embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
 				embds[k].weight[:] = model.embed_tokens.weight[start:end, :]

 				if classifier_idx >= 0:
@@ -2213,7 +2213,7 @@ if __name__ == "__main__":
 					heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
 			else:
 				embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-				embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+				embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
 				embds[k].load_state_dict({ "weight": embd_weight })

 				if classifier_idx >= 0:
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index cb40e9b..5c86ede 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -32,7 +32,7 @@ if cfg.optimizations.bitsandbytes:
 		Linear = bnb.nn.Linear8bitLt

 	if cfg.optimizations.embedding:
-		Embedding = bnb.nn.modules.Embedding
+		Embedding = bnb.nn.StableEmbedding
 		"""
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
 			input,
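
Note (not part of the commit): the nn.Embedding -> ml.Embedding swap in base.py points those call sites at the alias exported by vall_e/utils/wrapper.py, which the second hunk rebinds to bitsandbytes' StableEmbedding when the embedding optimization is enabled. A minimal sketch of that alias pattern, assuming a boolean flag standing in for cfg.optimizations.embedding (the names below are illustrative assumptions, not the exact wrapper code):

    import torch.nn as nn

    try:
        import bitsandbytes as bnb  # optional dependency
    except ImportError:
        bnb = None

    # Default to the plain torch embedding.
    Embedding = nn.Embedding

    # Hypothetical stand-in for cfg.optimizations.embedding.
    USE_BNB_EMBEDDING = bnb is not None

    if USE_BNB_EMBEDDING:
        # bnb.nn.StableEmbedding wraps nn.Embedding with a LayerNorm and is the
        # variant bitsandbytes recommends when training with its 8-bit optimizers.
        Embedding = bnb.nn.StableEmbedding

    # Call sites then construct embeddings through the alias, e.g.:
    # self.embeddings = nn.ModuleList([Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])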