UGH
parent 22ffaf3a33
commit 2567e082b5
@@ -117,10 +117,11 @@ class AR(Base):
 		device = text_list[0].device
 		batch_size = len(text_list)

-		resps_list: list[Tensor] = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
+		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
 		stopped = torch.zeros(batch_size, device=device).bool()

 		state = {} if cfg.inference.recurrent_forward else None

+		sampling_beam_width_use_logs = True
 		scores = [ 1.0 ] * sampling_beam_width

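The new sampling_beam_width_use_logs flag suggests beam scores are accumulated in log space rather than by multiplying raw probabilities. A minimal sketch (not the repo's code) of why that matters numerically:

# Minimal sketch, not the repo's code: multiplying raw probabilities
# underflows as sequences grow, so beam scores are usually kept as
# summed log-probabilities instead.
import math

step_probs = [0.9, 0.8, 0.85, 0.9]  # hypothetical per-step token probabilities

score_mul = 1.0   # probability-space score
score_log = 0.0   # log-space score
for p in step_probs:
    score_mul *= p
    score_log += math.log(p)

# both forms rank candidates identically, but only the log form stays stable
assert abs(math.exp(score_log) - score_mul) < 1e-12
print(score_mul, score_log)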
@@ -129,12 +130,12 @@ class AR(Base):

 		# get next in sequence
 		for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
+			resps_list = self._unsqueeze_list(sequence_list)
 			logits = super().forward(
 				text_list=text_list,
 				proms_list=proms_list,
-				resps_list=self._unsqueeze_list(resps_list),
-				quant_levels=None,
+				resps_list=resps_list,

 				state=state
 			)

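This hunk hoists the unsqueeze out of the forward() call: the AR's flat sequence_list is converted once per step into resps_list before being handed to the shared Base.forward. A hypothetical sketch of what _unsqueeze_list plausibly does, assuming it adds a trailing quantizer-level axis so the AR's (t,) token sequences match the (t, n_levels) shape Base.forward expects:

# Hypothetical reconstruction of the helper's effect; the axis
# convention is an assumption, not copied from the repo.
import torch

def unsqueeze_list(x_list, axis=-1):
    return [x.unsqueeze(dim=axis) for x in x_list]

sequence_list = [
    torch.zeros(0, dtype=torch.int16),
    torch.tensor([1, 2, 3], dtype=torch.int16),
]
resps_list = unsqueeze_list(sequence_list)
print([tuple(r.shape) for r in resps_list])  # [(0, 1), (3, 1)]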
@@ -173,8 +174,7 @@ class AR(Base):
 				for i, ri in enumerate(r):
 					if self.stop_token in ri:
 						stopped[i] = True
-
-					resps_list[i] = torch.cat([resps_list[i], ri])
+					sequence_list[i] = torch.cat([sequence_list[i], ri])

 			# stop token found
 			stopped |= r == self.stop_token
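For context, stopped is a per-sequence boolean mask, so the batch keeps generating until every sequence has emitted the stop token. A small sketch of that accumulation, with assumed shapes and a hypothetical stop-token id:

# Assumed shapes and a hypothetical stop-token id, showing how the
# per-sequence mask accumulates until the whole batch stops.
import torch

stop_token = 1024  # hypothetical id
stopped = torch.zeros(3).bool()

for r in (torch.tensor([5, 1024, 7]), torch.tensor([1024, 2, 9])):
    stopped |= r == stop_token   # sticky: once True, stays True
    if stopped.all():
        break

print(stopped)  # tensor([ True,  True, False])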
@@ -191,10 +191,9 @@ class AR(Base):
 			sequence_list = [sequence_list[best_idx]]

-		res = [self._prune(r) for r in resps_list]
 		if self.interleave:
-			res = [self._deinterleave(r) for r in res]
-		return res
+			sequence_list = [self._deinterleave(r) for r in sequence_list]
+		return [self._prune(r) for r in sequence_list]


 def example_usage():
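Note the return path also reorders the cleanup: the old code pruned first and then de-interleaved; the new code de-interleaves interleaved-codebook output back into per-level form before pruning at the stop token. A schematic sketch, with the repo's helpers passed in as parameters:

# Schematic only; deinterleave and prune stand in for the repo's
# self._deinterleave and self._prune helpers.
def finalize(sequence_list, interleave, deinterleave, prune):
    if interleave:
        sequence_list = [deinterleave(r) for r in sequence_list]
    return [prune(r) for r in sequence_list]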
@@ -150,7 +150,7 @@ class MultiEmbedding(nn.Module):
 	"""

 	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
-		super().__init__(max_n_levels, token_dim)
+		super().__init__()
 		self.monolithic = monolithic
 		self.max_n_levels = max_n_levels
 		self.n_tokens = n_tokens
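Since MultiEmbedding derives from nn.Module (per the hunk header), the old super().__init__(max_n_levels, token_dim) call was forwarding embedding-style arguments to nn.Module.__init__, which takes none and would raise a TypeError. A minimal sketch of the corrected shape, with the weight layout as an assumption for illustration:

import torch
from torch import nn

class MultiEmbedding(nn.Module):
    def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
        super().__init__()  # nn.Module.__init__ is argument-free
        self.monolithic = monolithic
        self.max_n_levels = max_n_levels
        self.n_tokens = n_tokens
        # one embedding table per quantizer level; this layout is an
        # assumption, not copied from the repo
        self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))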
@@ -69,6 +69,7 @@ class NAR(Base):
 		sampling_repetition_penalty: float = 1.0,
 		sampling_repetition_penalty_decay: float = 0.0,
 		sampling_length_penalty: float = 0.0,
+		sampling_beam_width: int = 0,
 	):
 		"""
 		Args:
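The NAR now accepts sampling_beam_width alongside the other sampling knobs, presumably for signature parity with the AR so one settings dict can drive both models. A hypothetical call site (an assumption about intent, not repo code):

settings = dict(
    sampling_temperature=1.0,
    sampling_repetition_penalty=1.0,
    sampling_repetition_penalty_decay=0.0,
    sampling_length_penalty=0.0,
    sampling_beam_width=0,  # newly accepted; likely unused by the NAR itself
)
# resps_list = nar(text_list, proms_list, resps_list=resps_list, **settings)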
@@ -129,9 +130,9 @@ class NAR(Base):

 			resps_list = super().sample(
 				logits=logits,
-				resps_list=resps_list,
+				resps_list=prev_list,
 				quant_levels=quant_levels,

 				temperature=sampling_temperature,
 				top_p=sampling_top_p,
 				top_k=sampling_top_k,
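Passing prev_list (the levels decoded so far) rather than the just-reassigned resps_list gives sample() the actual history to penalize against. A common repetition-penalty formulation, shown as an assumption about what sample() does with that history, not as the repo's implementation:

# CTRL-style repetition penalty over the token history.
import torch

def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    for t in prev_tokens.unique():
        # dampen tokens that already appeared in the history
        if logits[t] > 0:
            logits[t] /= penalty
        else:
            logits[t] *= penalty
    return logits

logits = torch.randn(1025)        # hypothetical vocab size
prev = torch.tensor([3, 3, 17])   # history, e.g. tokens from prev_list
logits = apply_repetition_penalty(logits, prev)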