faster

parent a507b769a1
commit 84c7419001
@@ -276,7 +276,7 @@ class AR_NAR(Base):
        # get next in sequence
        for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
            #
            # it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
            text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
            resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
@@ -291,6 +291,7 @@ class AR_NAR(Base):
                quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
            )

            # to-do: find an elegant way to write this
            if state is not None:
                logits, state = super().forward(
                    inputs=inputs,
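For context on the state branch: logits, state = super().forward(...) is the usual stateful-decoding call shape, where whatever state the previous step returned is fed back in so the next step does not have to re-process the prefix. A minimal sketch of that call pattern with a stand-in recurrent model (TinyStatefulModel is illustrative only, not the repo's Base.forward):

    import torch
    import torch.nn as nn

    class TinyStatefulModel(nn.Module):
        # illustrative stand-in for a forward() that returns (logits, state)
        def __init__(self, vocab=32, dim=16):
            super().__init__()
            self.emb = nn.Embedding(vocab, dim)
            self.rnn = nn.GRU(dim, dim, batch_first=True)
            self.head = nn.Linear(dim, vocab)

        def forward(self, tokens, state=None):
            x = self.emb(tokens)
            x, state = self.rnn(x, state)  # reuse the carried state instead of re-reading the prefix
            return self.head(x[:, -1]), state

    model = TinyStatefulModel()
    tokens = torch.randint(0, 32, (1, 1))
    state = None
    for _ in range(4):  # AR loop: thread the state through successive forwards
        logits, state = model(tokens, state)
        tokens = logits.argmax(dim=-1, keepdim=True)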
@@ -15,6 +15,7 @@ import torch.nn.functional as F
import random
import numpy as np
import re
from time import perf_counter

from typing import Literal, overload, Optional, Tuple
from functools import partial
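Side note: with perf_counter in the import block, the usual way to time a section (e.g. one decode step) looks like this; the timed region below is just a placeholder:

    from time import perf_counter

    t0 = perf_counter()
    # ... work being measured, e.g. a single forward/sampling step ...
    elapsed = perf_counter() - t0
    print(f"step took {elapsed:.4f}s")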
@@ -1479,8 +1480,8 @@ class Base(nn.Module):
        elif self.causal:
            logits = [ logit[-self.causal_size:] for logit in logits ]

-       devices = [ logit.device for logit in logits ]
-       logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+       # this might actually slow things down a bit slightly-er?
+       #logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]

        # (NAR) disable stop token
        if quant_levels is not None and "ar" in self.capabilities:
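The two removed lines copied every logit tensor to the CPU (upcasting float16 to float32 along the way) before sampling; the replacement comments that out and leaves logits on their original device. If only the dtype upcast were wanted without the host round-trip, that piece alone would look roughly like the sketch below (toy tensors, not the repo's code):

    import torch

    # toy logits on whatever device is available, half precision where that makes sense
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    logits = [ torch.randn(8, 1024, device=device, dtype=dtype) for _ in range(2) ]

    # upcast float16 for numerically safer sampling, but skip the device="cpu" move
    logits = [ logit.float() if logit.dtype == torch.float16 else logit for logit in logits ]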
@@ -1494,12 +1495,12 @@ class Base(nn.Module):
            return [ logit.argmax(dim=1) for logit in logits ]

        # perform repetition penalizing
-       if "len" not in self.capabilities and prev_list is not None:
+       if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
            # to-do: figure out a faster way to handle tolist()
            logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]

        # (AR) perform length penalizing
-       if quant_levels is None and self.causal and prev_list is not None:
+       if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
            logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

        # perform top_k/top_p filtering of our logits
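The added repetition_penalty != 1.0 and length_penalty != 0.0 guards skip the penalty passes when they would be no-ops, which also avoids the per-step tolist() that the to-do complains about. For reference, a typical repetition penalty (my own sketch of the common CTRL-style rule, not necessarily how reptition_penalize is implemented, and ignoring its decay argument) scales down the logits of already-emitted tokens:

    import torch

    def naive_repetition_penalize(logits, previous, factor=1.25):
        # CTRL-style rule: divide positive logits / multiply negative logits of seen tokens;
        # with factor == 1.0 this is a no-op, which is exactly what the new guard skips
        for token in set(previous):
            value = logits[token]
            logits[token] = value / factor if value > 0 else value * factor
        return logits

    logits = torch.randn(1024)
    previous = [3, 7, 7, 42]
    logits = naive_repetition_penalize(logits, previous)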