ugh
This commit is contained in:
parent a3b7260514
commit 35510c954c
@@ -135,7 +135,7 @@ def _interleave_sequence_flatten( input: list[torch.Tensor] ):
 # automagically parses a batch-list and returns it as a list
 """
 
-class Embedding(nn.Embedding):
+class Embedding(ml.Embedding):
 	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
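For context on what this swap buys: `ml` is the project's optimization wrapper module (see the `cfg.optimizations` hunk at the bottom of this diff), so `ml.Embedding` can resolve to an optimized backend instead of plain `nn.Embedding`. A minimal sketch of that indirection, with `USE_OPTIMIZED_EMBEDDING` standing in for the real `cfg.optimizations.embedding` flag and `BatchEmbedding` as a hypothetical stand-in for the class above:

import torch.nn as nn

USE_OPTIMIZED_EMBEDDING = False  # stand-in for the cfg.optimizations.embedding flag

if USE_OPTIMIZED_EMBEDDING:
	import bitsandbytes as bnb
	# layer-normed embedding that keeps 32-bit optimizer state
	Embedding = bnb.nn.StableEmbedding
else:
	Embedding = nn.Embedding

# downstream code then subclasses the alias, as the hunk above does:
class BatchEmbedding(Embedding):  # hypothetical name for illustration
	def forward(self, x_list: list) -> list:
		# embed each sequence of a batch-list independently
		return [Embedding.forward(self, x) for x in x_list]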
@@ -192,7 +192,7 @@ class AudioEmbedding_Old(nn.Module):
 		# array of embeddings
 		#	proms are [0, resp_levels]
 		#	resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
 		self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
@@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
 		# array of embeddings
 		#	proms are [0, resp_levels]
 		#	resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
 		#
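Since both `AudioEmbedding_Old` and `AudioEmbedding` keep one embedding table per quant level, here is a hedged sketch of how such a stack is typically consumed; the `sums` toggle and the class name are assumptions drawn from the surrounding comments, not the repo's exact forward:

import torch
import torch.nn as nn

class LevelledAudioEmbedding(nn.Module):  # hypothetical name for illustration
	def __init__(self, l_tokens: list, token_dim: int, sums: bool = True):
		super().__init__()
		self.embeddings = nn.ModuleList([nn.Embedding(n, token_dim) for n in l_tokens])
		self.sums = sums

	def forward(self, xi: torch.Tensor) -> torch.Tensor:
		# xi: [seq_len, n_levels] integer codes, one column per quant level
		if self.sums:
			# collapse all levels into one [seq_len, token_dim] embedding
			return sum(emb(xi[:, l]) for l, emb in enumerate(self.embeddings))
		# otherwise only the last provided level contributes
		return self.embeddings[xi.shape[-1] - 1](xi[:, -1])

emb = LevelledAudioEmbedding([1024] * 8, token_dim=512)
codes = torch.randint(0, 1024, (100, 8))
print(emb(codes).shape)  # torch.Size([100, 512])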
@@ -350,7 +350,7 @@ class AudioEncoder(nn.Module):
 		token_dim: int,
 	):
 		super().__init__()
-		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
 		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
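The `nn.Linear(8 * token_dim, 1 * token_dim)` alongside the per-level tables suggests the encoder concatenates the (presumably 8) level embeddings and projects back down to `token_dim`; a shape-math sketch under that assumption:

import torch

n_levels, token_dim, n_tokens, seq_len = 8, 512, 1024, 100
embs = torch.nn.ModuleList([torch.nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
proj = torch.nn.Linear(n_levels * token_dim, token_dim)

codes = torch.randint(0, n_tokens, (seq_len, n_levels))       # xi: [seq_len, n_levels]
per_level = [emb(codes[:, l]) for l, emb in enumerate(embs)]  # 8 x [seq_len, token_dim]
out = proj(torch.cat(per_level, dim=-1))                      # [seq_len, token_dim]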
@@ -2201,7 +2201,7 @@ if __name__ == "__main__":
 			if is_from_pretrained:
 				n_vocab = end - start
 
-				embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+				embds[k] = ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
 				embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
 
 				if classifier_idx >= 0:
@@ -2213,7 +2213,7 @@ if __name__ == "__main__":
 					heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
 			else:
 				embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-				embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+				embds[k] = ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
 				embds[k].load_state_dict({ "weight": embd_weight })
 
 				if classifier_idx >= 0:
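Both `__main__` hunks implement the same pattern: carve a row range out of a pretrained embedding table into a standalone module. A self-contained sketch of that slicing (names here are illustrative, not the repo's; the in-place copy needs `no_grad` when run outside an inference context):

import torch

pretrained = torch.nn.Embedding(1000, 64)  # stand-in for model.embed_tokens
start, end = 0, 256                        # vocab slice owned by this sub-embedding
n_vocab, n_embd = end - start, pretrained.weight.shape[1]

embd = torch.nn.Embedding(n_vocab, n_embd).to(pretrained.weight)  # match device/dtype
with torch.no_grad():
	embd.weight[:] = pretrained.weight[start:end, :]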
@@ -32,7 +32,7 @@ if cfg.optimizations.bitsandbytes:
 		Linear = bnb.nn.Linear8bitLt
 
 	if cfg.optimizations.embedding:
-		Embedding = bnb.nn.modules.Embedding
+		Embedding = bnb.nn.StableEmbedding
 		"""
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
 			input,
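`bnb.nn.StableEmbedding` is a drop-in replacement for `nn.Embedding` that layer-norms its output and keeps its optimizer state in 32-bit, which is why it is the safer choice when bitsandbytes' 8-bit optimizers are enabled. A quick usage check, assuming bitsandbytes is installed:

import torch
import bitsandbytes as bnb

emb = bnb.nn.StableEmbedding(1024, 512)   # same signature as nn.Embedding
tokens = torch.randint(0, 1024, (4, 16))
out = emb(tokens)                         # [4, 16, 512], layer-normed embeddings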