Absorb stop token into the model

This commit is contained in:
enhuiz 2023-01-12 00:09:11 +08:00
parent 43483bb394
commit 6c5f250faa

View File

@ -203,13 +203,22 @@ class VALLEAR(nn.Module):
# Here, simply use num_tokens := max(num_text_tokens, num_prompt_tokens, num_output_tokens)
self.text_emb = ListEmbedding(num_tokens, d_model)
self.prompt_emb = ListEmbedding(num_tokens, d_model)
self.output_emb = ListEmbedding(num_tokens, d_model)
# +1 to include the stop token
self.output_emb = ListEmbedding(num_tokens + 1, d_model)
self.sin_emb = SinusodialEmbedding(d_model)
self.sep = nn.Parameter(torch.randn(d_model)) # start of sequence token
self.blocks = nn.ModuleList(
[Block(d_model, num_heads, dropout) for _ in range(num_layers)]
)
self.fc = nn.Linear(d_model, num_tokens)
self.fc = nn.Linear(d_model, num_tokens + 1)
@property
def num_tokens(self):
    """Return the number of regular output tokens, excluding the stop token.

    The output embedding is created with ``num_tokens + 1`` entries so that
    it also covers the stop token; that extra entry is subtracted here.
    """
    return self.output_emb.num_embeddings - 1
@property
def _stop_index(self):
return self.num_tokens
@property
def _ignore_index(self):
@ -264,7 +273,7 @@ class VALLEAR(nn.Module):
# shift y_list one step earlier, since it is the future (next) token that is unknown
for i in range(len(y_list)):
y_list[i] = y_list[i].roll(-1, dims=0)
y_list[i][-1] = self._ignore_index
y_list[i][-1] = self._stop_index
self.loss = dict(
nll=F.cross_entropy(
@ -278,16 +287,10 @@ class VALLEAR(nn.Module):
return logits
@staticmethod
def _prune(l: Tensor, stop_token: int | None):
if stop_token is None:
return l
indices = (l == stop_token).nonzero()
def _prune(self, l: Tensor):
indices = (l == self._stop_index).nonzero()
if len(indices) == 0:
return l
return l[: indices[0].item()]
def generate(
@ -295,7 +298,6 @@ class VALLEAR(nn.Module):
text_list: list[Tensor],
prompt_list: list[Tensor],
max_steps: int = 1000,
stop_token: int | None = None,
):
device = text_list[0].device
output_list: list[Tensor] = [
@ -311,12 +313,12 @@ class VALLEAR(nn.Module):
)
o = Categorical(logits=logits).sample()
for i, oi in enumerate(o):
if oi.item() == stop_token:
if oi.item() == self._stop_index:
stopped[i] = True
output_list[i] = torch.cat([output_list[i], oi[None]])
if all(stopped):
break
pruned = [self._prune(o, stop_token) for o in output_list]
pruned = [self._prune(o) for o in output_list]
return pruned
@ -325,9 +327,8 @@ def example_usage():
device = "cuda"
test_qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device)
num_qnts = 1024 + 1
eoq = num_qnts - 1
qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device)
num_qnts = 1024
model = VALLEAR(num_qnts).to(device)
@ -342,15 +343,14 @@ def example_usage():
]
output_list = [
torch.tensor([1, 2, 3, eoq], device=device),
torch.tensor([*test_qnt, eoq], device=device),
torch.tensor([1, 2, 3], device=device),
torch.tensor(qnt, device=device),
]
out = model.generate(
text_list,
prompt_list,
max_steps=200,
stop_token=eoq,
)
print(out)
@ -368,13 +368,9 @@ def example_usage():
if i % 20 == 0:
print(f"iter={i}, {losses}.")
out = model.generate(
text_list,
prompt_list,
max_steps=200,
stop_token=eoq,
)
out = model.generate(text_list, prompt_list, max_steps=200)
print(qnt)
print(out)
from ..emb.qnt import decode