some day I'll get it right
This commit is contained in:
parent b2907ae7e0
commit f69aad9c65
@@ -163,7 +163,6 @@ class Model:
 	arch_type: str = "transformer"
 	training: bool = True
 	interleave: bool = False
-	use_multiembedding: bool = True # nasty bandaid I got myself into
 	frozen_params: list[str] = field(default_factory=lambda: [])
 
 	@property
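Note: with use_multiembedding dropped, the fields visible in this hunk leave the config looking roughly like the sketch below (other fields omitted; an illustration, not a verbatim copy of the repository's dataclass).

from dataclasses import dataclass, field

# rough shape of the trimmed config after this hunk; only the fields shown
# in the diff are kept, everything else is omitted
@dataclass
class Model:
	arch_type: str = "transformer"
	training: bool = True
	interleave: bool = False
	frozen_params: list[str] = field(default_factory=lambda: [])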
@@ -107,7 +107,7 @@ class AR_NAR(Base):
 			prev_list = resps_list
 
 			while True:
-				level = prev_list[0].shape[-1] - 1
+				level = prev_list[0].shape[-1]
 
 				if level >= self.n_resp_levels:
 					break
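Note: prev_list[0].shape[-1] reads as the number of RVQ levels already present in each response, so dropping the "- 1" makes the loop target the next missing level and stop once n_resp_levels levels exist, instead of starting back at level 0 (already produced by the AR) and running one extra pass. A toy illustration of just the exit arithmetic, with made-up sizes:

import torch

# made-up sizes purely to illustrate the loop arithmetic; not taken from the repo
n_resp_levels = 4
resps = torch.zeros(10, 1)  # t x levels: one RVQ level produced by the AR so far

targeted = []
while True:
	level = resps.shape[-1]  # new behaviour: index of the next missing level
	if level >= n_resp_levels:
		break
	targeted.append(level)
	resps = torch.cat([resps, torch.zeros(10, 1)], dim=-1)  # pretend the NAR filled that level

print(targeted)  # [1, 2, 3]; with the old "- 1" this would have been [0, 1, 2, 3]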
@@ -64,12 +64,9 @@ class MultiEmbedding(nn.Embedding):
 
 	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
 		super().__init__(max_n_levels, token_dim)
+		self.monolithic = monolithic
 		self.max_n_levels = max_n_levels
 		self.n_tokens = n_tokens
-		self.monolithic = monolithic
-		if self.monolithic:
-			self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))
-		else:
-			self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
+		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
 
 		# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
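Note: the constructor change swaps the monolithic-only weights parameter (leading dim of 2, one table for the AR and one for the NAR) for a single weight table; the AR/NAR split now happens at lookup time by slicing the level dimension, as the next hunk shows. A shape-only sketch with assumed toy sizes:

import torch
from torch import nn

# assumed toy sizes, only to contrast the two parameterisations in this commit
max_n_levels, n_tokens, token_dim = 8, 1024, 512

old_weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))  # before: stacked AR/NAR tables
new_weight  = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))     # after: one table, sliced later

print(old_weights.shape)  # torch.Size([2, 8, 1024, 512])
print(new_weight.shape)   # torch.Size([8, 1024, 512])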
@@ -78,14 +75,16 @@ class MultiEmbedding(nn.Embedding):
 		if len(x_list) == 0:
 			return []
 
+		# this "strategy" will reserve the weight[0] for the AR and weight[1:] for the NAR
+		# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
 		if self.monolithic:
-			w = self.weights[0 if quant_levels is None else 1]
+			w = self.weight[:1] if quant_levels is None else self.weight[1:]
 		else:
 			w = self.weight
 
 		padded_x_list = []
 
-		for xi in x_list:
+		for i, xi in enumerate(x_list):
 			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
 			wi = w.shape[0] - xi.shape[1]
 			xi = F.pad(xi, (0, 0, 0, wi)) # t l k
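Note: a toy walk-through of the new slicing-and-padding path, with made-up sizes; the final einsum is only a plausible guess at how the padded one-hots become embeddings and is not taken from the repository.

import torch
import torch.nn.functional as F

# made-up sizes for illustration only
max_n_levels, n_tokens, token_dim = 8, 1024, 16
weight = torch.randn(max_n_levels, n_tokens, token_dim)

quant_levels = None                                       # None -> AR pass
w = weight[:1] if quant_levels is None else weight[1:]    # AR keeps level 0, NAR keeps 1:

xi = torch.randint(0, n_tokens, (10, 1))                  # t x l' token ids (AR: a single RVQ level)
xi = F.one_hot(xi.to(torch.int64), num_classes=n_tokens)  # t l' k
wi = w.shape[0] - xi.shape[1]                             # pad the level dim up to w's level count
xi = F.pad(xi, (0, 0, 0, wi))                             # t l k

# plausible reduction into an embedding (an assumption, not lifted from the repo)
emb = torch.einsum("lkd,tlk->td", w, xi.float())
print(emb.shape)                                          # torch.Size([10, 16])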