Simplify interface
This commit is contained in:
parent
ea5e438fdb
commit
5e4ef084b8
|
@ -23,14 +23,16 @@ class AR(Base):
|
||||||
indices = (l == self.stop_token).nonzero()
|
indices = (l == self.stop_token).nonzero()
|
||||||
if len(indices) == 0:
|
if len(indices) == 0:
|
||||||
return l
|
return l
|
||||||
return l[: indices[0].item()]
|
return l[: indices.min().item()]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
text_list: list[Tensor],
|
text_list: list[Tensor],
|
||||||
proms_list: list[Tensor],
|
proms_list: list[Tensor],
|
||||||
resp_list: list[Tensor],
|
resp_list: list[Tensor] | None = None,
|
||||||
|
max_steps: int = 1000,
|
||||||
):
|
):
|
||||||
|
if resp_list is not None:
|
||||||
return super().forward(
|
return super().forward(
|
||||||
text_list,
|
text_list,
|
||||||
proms_list,
|
proms_list,
|
||||||
|
@ -40,12 +42,14 @@ class AR(Base):
|
||||||
shift_targ_list=True,
|
shift_targ_list=True,
|
||||||
return_all_resp=False,
|
return_all_resp=False,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
return self._generate(text_list, proms_list, max_steps)
|
||||||
|
|
||||||
def generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
text_list: list[Tensor],
|
text_list: list[Tensor],
|
||||||
proms_list: list[Tensor],
|
proms_list: list[Tensor],
|
||||||
max_steps: int = 1000,
|
max_steps: int,
|
||||||
):
|
):
|
||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
resp_list: list[Tensor] = [
|
resp_list: list[Tensor] = [
|
||||||
|
@ -93,11 +97,7 @@ def example_usage():
|
||||||
qnt.to(device),
|
qnt.to(device),
|
||||||
]
|
]
|
||||||
|
|
||||||
out = model.generate(
|
out = model(text_list, proms_list, max_steps=200)
|
||||||
text_list,
|
|
||||||
proms_list,
|
|
||||||
max_steps=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(out)
|
print(out)
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ def example_usage():
|
||||||
if i % 20 == 0:
|
if i % 20 == 0:
|
||||||
print(f"iter={i}, {losses}.")
|
print(f"iter={i}, {losses}.")
|
||||||
|
|
||||||
out = model.generate(text_list, proms_list, max_steps=200)
|
out = model(text_list, proms_list, max_steps=200)
|
||||||
|
|
||||||
print(qnt)
|
print(qnt)
|
||||||
print(out)
|
print(out)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user