Absorb stop token into the model
This commit is contained in:
parent
43483bb394
commit
6c5f250faa
@ -203,13 +203,22 @@ class VALLEAR(nn.Module):
|
||||
# Here, simply use num_tokens := max(num_text_tokens, num_prompt_tokens, num_output_tokens)
|
||||
self.text_emb = ListEmbedding(num_tokens, d_model)
|
||||
self.prompt_emb = ListEmbedding(num_tokens, d_model)
|
||||
self.output_emb = ListEmbedding(num_tokens, d_model)
|
||||
# +1 to include the stop token
|
||||
self.output_emb = ListEmbedding(num_tokens + 1, d_model)
|
||||
self.sin_emb = SinusodialEmbedding(d_model)
|
||||
self.sep = nn.Parameter(torch.randn(d_model)) # start of sequence token
|
||||
self.blocks = nn.ModuleList(
|
||||
[Block(d_model, num_heads, dropout) for _ in range(num_layers)]
|
||||
)
|
||||
self.fc = nn.Linear(d_model, num_tokens)
|
||||
self.fc = nn.Linear(d_model, num_tokens + 1)
|
||||
|
||||
@property
|
||||
def num_tokens(self):
|
||||
return self.output_emb.num_embeddings - 1
|
||||
|
||||
@property
|
||||
def _stop_index(self):
|
||||
return self.num_tokens
|
||||
|
||||
@property
|
||||
def _ignore_index(self):
|
||||
@ -264,7 +273,7 @@ class VALLEAR(nn.Module):
|
||||
# make y_list earlier as it is future that is unknown
|
||||
for i in range(len(y_list)):
|
||||
y_list[i] = y_list[i].roll(-1, dims=0)
|
||||
y_list[i][-1] = self._ignore_index
|
||||
y_list[i][-1] = self._stop_index
|
||||
|
||||
self.loss = dict(
|
||||
nll=F.cross_entropy(
|
||||
@ -278,16 +287,10 @@ class VALLEAR(nn.Module):
|
||||
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _prune(l: Tensor, stop_token: int | None):
|
||||
if stop_token is None:
|
||||
return l
|
||||
|
||||
indices = (l == stop_token).nonzero()
|
||||
|
||||
def _prune(self, l: Tensor):
|
||||
indices = (l == self._stop_index).nonzero()
|
||||
if len(indices) == 0:
|
||||
return l
|
||||
|
||||
return l[: indices[0].item()]
|
||||
|
||||
def generate(
|
||||
@ -295,7 +298,6 @@ class VALLEAR(nn.Module):
|
||||
text_list: list[Tensor],
|
||||
prompt_list: list[Tensor],
|
||||
max_steps: int = 1000,
|
||||
stop_token: int | None = None,
|
||||
):
|
||||
device = text_list[0].device
|
||||
output_list: list[Tensor] = [
|
||||
@ -311,12 +313,12 @@ class VALLEAR(nn.Module):
|
||||
)
|
||||
o = Categorical(logits=logits).sample()
|
||||
for i, oi in enumerate(o):
|
||||
if oi.item() == stop_token:
|
||||
if oi.item() == self._stop_index:
|
||||
stopped[i] = True
|
||||
output_list[i] = torch.cat([output_list[i], oi[None]])
|
||||
if all(stopped):
|
||||
break
|
||||
pruned = [self._prune(o, stop_token) for o in output_list]
|
||||
pruned = [self._prune(o) for o in output_list]
|
||||
return pruned
|
||||
|
||||
|
||||
@ -325,9 +327,8 @@ def example_usage():
|
||||
|
||||
device = "cuda"
|
||||
|
||||
test_qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device)
|
||||
num_qnts = 1024 + 1
|
||||
eoq = num_qnts - 1
|
||||
qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device)
|
||||
num_qnts = 1024
|
||||
|
||||
model = VALLEAR(num_qnts).to(device)
|
||||
|
||||
@ -342,15 +343,14 @@ def example_usage():
|
||||
]
|
||||
|
||||
output_list = [
|
||||
torch.tensor([1, 2, 3, eoq], device=device),
|
||||
torch.tensor([*test_qnt, eoq], device=device),
|
||||
torch.tensor([1, 2, 3], device=device),
|
||||
torch.tensor(qnt, device=device),
|
||||
]
|
||||
|
||||
out = model.generate(
|
||||
text_list,
|
||||
prompt_list,
|
||||
max_steps=200,
|
||||
stop_token=eoq,
|
||||
)
|
||||
|
||||
print(out)
|
||||
@ -368,13 +368,9 @@ def example_usage():
|
||||
if i % 20 == 0:
|
||||
print(f"iter={i}, {losses}.")
|
||||
|
||||
out = model.generate(
|
||||
text_list,
|
||||
prompt_list,
|
||||
max_steps=200,
|
||||
stop_token=eoq,
|
||||
)
|
||||
out = model.generate(text_list, prompt_list, max_steps=200)
|
||||
|
||||
print(qnt)
|
||||
print(out)
|
||||
|
||||
from ..emb.qnt import decode
|
||||
|
||||
Loading…
Reference in New Issue
Block a user