This commit is contained in:
mrq 2023-09-16 00:26:13 -05:00
parent 22ffaf3a33
commit 2567e082b5
3 changed files with 12 additions and 12 deletions

View File

@@ -117,10 +117,11 @@ class AR(Base):
device = text_list[0].device
batch_size = len(text_list)
resps_list: list[Tensor] = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
stopped = torch.zeros(batch_size, device=device).bool()
state = {} if cfg.inference.recurrent_forward else None
sampling_beam_width_use_logs = True
scores = [ 1.0 ] * sampling_beam_width
@@ -129,12 +130,12 @@ class AR(Base):
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
resps_list = self._unsqueeze_list(sequence_list)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=self._unsqueeze_list(resps_list),
quant_levels=None,
resps_list=resps_list,
state=state
)
@@ -173,8 +174,7 @@ class AR(Base):
for i, ri in enumerate(r):
if self.stop_token in ri:
stopped[i] = True
resps_list[i] = torch.cat([resps_list[i], ri])
sequence_list[i] = torch.cat([sequence_list[i], ri])
# stop token found
stopped |= r == self.stop_token
@@ -191,10 +191,9 @@ class AR(Base):
sequence_list = [sequence_list[best_idx]]
res = [self._prune(r) for r in resps_list]
if self.interleave:
res = [self._deinterleave(r) for r in res]
return res
sequence_list = [self._deinterleave(r) for r in sequence_list]
return [self._prune(r) for r in sequence_list]
def example_usage():

View File

@@ -150,7 +150,7 @@ class MultiEmbedding(nn.Module):
"""
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
super().__init__(max_n_levels, token_dim)
super().__init__()
self.monolithic = monolithic
self.max_n_levels = max_n_levels
self.n_tokens = n_tokens

View File

@@ -69,6 +69,7 @@ class NAR(Base):
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
):
"""
Args:
@@ -129,7 +130,7 @@ class NAR(Base):
resps_list = super().sample(
logits=logits,
resps_list=resps_list,
resps_list=prev_list,
quant_levels=quant_levels,
temperature=sampling_temperature,