Simplify interface

This commit is contained in:
enhuiz 2023-01-12 14:24:35 +08:00
parent ea5e438fdb
commit 5e4ef084b8

View File

@ -23,14 +23,16 @@ class AR(Base):
indices = (l == self.stop_token).nonzero() indices = (l == self.stop_token).nonzero()
if len(indices) == 0: if len(indices) == 0:
return l return l
return l[: indices[0].item()] return l[: indices.min().item()]
def forward( def forward(
self, self,
text_list: list[Tensor], text_list: list[Tensor],
proms_list: list[Tensor], proms_list: list[Tensor],
resp_list: list[Tensor], resp_list: list[Tensor] | None = None,
max_steps: int = 1000,
): ):
if resp_list is not None:
return super().forward( return super().forward(
text_list, text_list,
proms_list, proms_list,
@ -40,12 +42,14 @@ class AR(Base):
shift_targ_list=True, shift_targ_list=True,
return_all_resp=False, return_all_resp=False,
) )
else:
return self._generate(text_list, proms_list, max_steps)
def generate( def _generate(
self, self,
text_list: list[Tensor], text_list: list[Tensor],
proms_list: list[Tensor], proms_list: list[Tensor],
max_steps: int = 1000, max_steps: int,
): ):
device = text_list[0].device device = text_list[0].device
resp_list: list[Tensor] = [ resp_list: list[Tensor] = [
@ -93,11 +97,7 @@ def example_usage():
qnt.to(device), qnt.to(device),
] ]
out = model.generate( out = model(text_list, proms_list, max_steps=200)
text_list,
proms_list,
max_steps=200,
)
print(out) print(out)
@ -114,7 +114,7 @@ def example_usage():
if i % 20 == 0: if i % 20 == 0:
print(f"iter={i}, {losses}.") print(f"iter={i}, {losses}.")
out = model.generate(text_list, proms_list, max_steps=200) out = model(text_list, proms_list, max_steps=200)
print(qnt) print(qnt)
print(out) print(out)