some day I'll get it right

commit f69aad9c65 (parent b2907ae7e0)
@@ -163,7 +163,6 @@ class Model:
 	arch_type: str = "transformer"
 	training: bool = True
 	interleave: bool = False
-	use_multiembedding: bool = True # nasty bandaid I got myself into
 	frozen_params: list[str] = field(default_factory=lambda: [])
 
 	@property
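The removed `use_multiembedding` flag was a plain config field on the `Model` dataclass. As a side note on the surviving lines: `frozen_params` goes through `field(default_factory=lambda: [])` because dataclasses reject mutable defaults outright. A minimal self-contained sketch of that pattern, using only the fields visible in the hunk (the real class has more):

	from dataclasses import dataclass, field

	@dataclass
	class Model:
		arch_type: str = "transformer"
		training: bool = True
		interleave: bool = False
		# a bare default of [] would raise
		# "ValueError: mutable default <class 'list'> for field ... is not allowed",
		# so the default goes through a factory instead
		frozen_params: list[str] = field(default_factory=lambda: [])

	cfg = Model()
	cfg.frozen_params.append("resp_emb")  # mutates this instance's list only

`field(default_factory=list)` would be the more idiomatic spelling; the lambda behaves identically.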
@@ -107,7 +107,7 @@ class AR_NAR(Base):
 			prev_list = resps_list
 
 			while True:
-				level = prev_list[0].shape[-1] - 1
+				level = prev_list[0].shape[-1]
 
 				if level >= self.n_resp_levels:
					break
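The dropped `- 1` was an off-by-one: each entry of `prev_list` is shaped (t, levels-so-far), so `shape[-1]` already equals the index of the next RVQ level the NAR should produce, and subtracting one made the `level >= self.n_resp_levels` check fire an iteration late. A toy sketch of the corrected loop's termination, with dummy tensors and `n_resp_levels` assumed to be 4 purely for illustration:

	import torch

	n_resp_levels = 4  # assumed for illustration; the real value comes from the model config

	# stand-in for the AR's output: one RVQ level per utterance, shape (t, 1)
	prev_list = [torch.zeros(100, 1, dtype=torch.long)]

	while True:
		# levels generated so far == index of the next level to produce
		level = prev_list[0].shape[-1]

		if level >= n_resp_levels:
			break

		# stand-in for one NAR pass: tack one more quantizer level onto each response
		prev_list = [torch.cat([r, torch.zeros_like(r[..., :1])], dim=-1) for r in prev_list]

	print(prev_list[0].shape)  # torch.Size([100, 4])

With the old `- 1`, `level` lags by one, the loop runs a fifth iteration, and the output carries an extra quantizer level.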
@@ -64,13 +64,10 @@ class MultiEmbedding(nn.Embedding):
 
 	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
 		super().__init__(max_n_levels, token_dim)
+		self.monolithic = monolithic
 		self.max_n_levels = max_n_levels
 		self.n_tokens = n_tokens
-		self.monolithic = monolithic
-		if self.monolithic:
-			self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))
-		else:
-			self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
+		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
 
	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
	# I imagine this is an oversight in the NAR.
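This consolidates the embedding storage: the old monolithic path kept a `(2, max_n_levels, n_tokens, token_dim)` tensor, a full copy of the table for the AR (index 0) and another for the NAR (index 1), while the new code keeps a single `(max_n_levels, n_tokens, token_dim)` table and lets `forward` slice it per role. A shape-only sketch of the two layouts (the dimension values are arbitrary picks, not the repo's defaults):

	import torch

	max_n_levels, n_tokens, token_dim = 8, 1024, 512

	# old layout: one full table per role (0 = AR, 1 = NAR)
	weights = torch.randn(2, max_n_levels, n_tokens, token_dim)

	# new layout: one table; the AR takes row 0, the NAR takes rows 1:
	weight = torch.randn(max_n_levels, n_tokens, token_dim)

	w_ar  = weight[:1]  # (1, n_tokens, token_dim)
	w_nar = weight[1:]  # (max_n_levels - 1, n_tokens, token_dim)
	print(weights.numel() // weight.numel())  # 2: the old layout doubled the parameters

Halving the parameter count is the point; the cost is that the AR and NAR now share rows of one table, which is exactly what the next hunk implements.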
@@ -78,14 +75,16 @@ class MultiEmbedding(nn.Embedding):
 		if len(x_list) == 0:
 			return []
 
+		# this "strategy" will reserve the weight[0] for the AR and weight[1:] for the NAR
+		# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
 		if self.monolithic:
-			w = self.weights[0 if quant_levels is None else 1]
+			w = self.weight[:1] if quant_levels is None else self.weight[1:]
 		else:
 			w = self.weight
 
 		padded_x_list = []
 
-		for xi in x_list:
+		for i, xi in enumerate(x_list):
 			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
 			wi = w.shape[0] - xi.shape[1]
 			xi = F.pad(xi, (0, 0, 0, wi)) # t l k
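The loop body implements embedding lookup as a matmul: token ids become one-hot vectors (`t l' k`), get zero-padded along the level axis up to `w`'s level count (`t l k`), and can then be contracted against `w` (`l k d`) so each level picks out its embedding row while padded levels contribute zero. The contraction itself sits past the end of the hunk; a hedged reconstruction of the trick (the einsum spelling here is mine, not necessarily the repo's exact line):

	import torch
	import torch.nn.functional as F
	from torch import einsum

	max_n_levels, n_tokens, token_dim = 8, 1024, 512
	w = torch.randn(max_n_levels, n_tokens, token_dim)        # l k d

	xi = torch.randint(0, n_tokens, (100, 2))                 # t l': 100 frames, 2 RVQ levels present
	xi = F.one_hot(xi.to(torch.int64), num_classes=n_tokens)  # t l' k
	wi = w.shape[0] - xi.shape[1]                             # levels to pad, as in the hunk
	xi = F.pad(xi, (0, 0, 0, wi))                             # t l k; missing levels are all-zero rows

	# one-hot times table == lookup; summing over l adds the per-level embeddings
	emb = einsum("t l k, l k d -> t d", xi.float(), w)
	print(emb.shape)  # torch.Size([100, 512])

The newly introduced `i` in `for i, xi in enumerate(x_list):` is unused within the visible hunk; presumably the lines below the cut use it.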