mrq 2025-02-13 16:47:59 -06:00
parent a3b7260514
commit 35510c954c
2 changed files with 7 additions and 7 deletions

View File

@@ -135,7 +135,7 @@ def _interleave_sequence_flatten( input: list[torch.Tensor] ):
# automagically parses a batch-list and returns it as a list
"""
-class Embedding(nn.Embedding):
+class Embedding(ml.Embedding):
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0:
return []
@@ -192,7 +192,7 @@ class AudioEmbedding_Old(nn.Module):
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
@@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
# further experimentation is needed to see if this actually is useful
self.sums = sums
#
@@ -350,7 +350,7 @@ class AudioEncoder(nn.Module):
token_dim: int,
):
super().__init__()
-self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
@@ -2201,7 +2201,7 @@ if __name__ == "__main__":
if is_from_pretrained:
n_vocab = end - start
-embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
if classifier_idx >= 0:
@@ -2213,7 +2213,7 @@ if __name__ == "__main__":
heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
else:
embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
embds[k].load_state_dict({ "weight": embd_weight })
if classifier_idx >= 0:
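
In this first file, every direct nn.Embedding(...) construction is replaced with the ml.Embedding alias, so the concrete embedding class is chosen in one place (the helper module changed below) rather than at each call site. A minimal sketch of that alias pattern, assuming the module name ml and the cfg.optimizations flags seen in the diff context; the real module layout may differ:

import torch.nn as nn

# Default to the stock torch layers.
Embedding = nn.Embedding
Linear = nn.Linear

try:
    import bitsandbytes as bnb
    # Rebind the aliases when the bitsandbytes optimizations are enabled;
    # bnb.nn.StableEmbedding wraps the lookup in a LayerNorm and keeps its
    # optimizer state in 32-bit for training stability.
    Embedding = bnb.nn.StableEmbedding
    Linear = bnb.nn.Linear8bitLt
except ImportError:
    pass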

View File

@@ -32,7 +32,7 @@ if cfg.optimizations.bitsandbytes:
Linear = bnb.nn.Linear8bitLt
if cfg.optimizations.embedding:
-Embedding = bnb.nn.modules.Embedding
+Embedding = bnb.nn.StableEmbedding
"""
Embedding.forward = lambda self, input: ( self.norm(F.embedding(
input,
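
The second file is the alias module itself: under cfg.optimizations.embedding, the Embedding alias now points at bnb.nn.StableEmbedding instead of bnb.nn.modules.Embedding, picking up StableEmbedding's LayerNorm and 32-bit optimizer states. Because the alias keeps nn.Embedding's constructor and call signature, the model-side edits above need no further changes. A usage sketch, with the import path for ml assumed rather than taken from the commit:

import torch
import torch.nn as nn
import ml  # the project's helper module sketched above; the path is an assumption

# Constructed exactly like nn.Embedding, whichever class the alias resolves to.
embs = nn.ModuleList([ml.Embedding(1024, 512) for _ in range(8)])
tokens = torch.randint(0, 1024, (4, 75))
out = embs[0](tokens)  # -> (4, 75, 512)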