mrq 2024-10-04 22:30:47 -05:00
parent a507b769a1
commit 84c7419001
2 changed files with 7 additions and 5 deletions

View File

@@ -276,7 +276,7 @@ class AR_NAR(Base):
 # get next in sequence
 for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-	#
+	# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
 text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
 resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
@@ -291,6 +291,7 @@ class AR_NAR(Base):
 quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 )
+# to-do: find an elegant way to write this
 if state is not None:
 logits, state = super().forward(
 inputs=inputs,
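The comment added in this hunk points at a possible micro-optimization of the AR loop: rather than rebuilding the input lists and re-embedding the whole sequence every step, only the newest token's embedding could be appended to a cache. A minimal, self-contained sketch of that tradeoff, with illustrative names only (this is not the repo's code):

import torch
import torch.nn as nn

embed = nn.Embedding(1024, 512)  # stand-in for the model's token embedding

def rebuild_each_step(token_ids: list[int]) -> torch.Tensor:
    # what the loop effectively does now: re-embed the full sequence every iteration
    return embed(torch.tensor(token_ids))

def append_only(cache, new_token: int) -> torch.Tensor:
    # the "technically faster" alternative: embed just the new token and concatenate
    new_emb = embed(torch.tensor([new_token]))
    return new_emb if cache is None else torch.cat([cache, new_emb], dim=0)

# both paths yield identical embeddings; the cached path only saves re-embedding
# the prefix, which is tiny next to the transformer forward pass itself
tokens, cache = [], None
for tok in [5, 42, 7]:
    tokens.append(tok)
    cache = append_only(cache, tok)
    assert torch.allclose(rebuild_each_step(tokens), cache)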

View File

@@ -15,6 +15,7 @@ import torch.nn.functional as F
 import random
 import numpy as np
 import re
+from time import perf_counter
 from typing import Literal, overload, Optional, Tuple
 from functools import partial
@@ -1479,8 +1480,8 @@ class Base(nn.Module):
 elif self.causal:
 logits = [ logit[-self.causal_size:] for logit in logits ]
-devices = [ logit.device for logit in logits ]
-logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+# this might actually slow things down a bit slightly-er?
+#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 # (NAR) disable stop token
 if quant_levels is not None and "ar" in self.capabilities:
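For context on the lines commented out above: they copied every per-utterance logit tensor to the CPU (and widened float16 to float32) before sampling. A hedged sketch of the direction the commit moves toward, keeping logits on whatever device they already live on and only widening half precision; the helper name is illustrative, not the repo's API:

import torch

def prepare_logits_for_sampling(logits: list[torch.Tensor]) -> list[torch.Tensor]:
    # leave each tensor on its original device; only upcast float16 so the
    # later sampling math doesn't run in half precision
    return [ l.float() if l.dtype == torch.float16 else l for l in logits ]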
@@ -1494,12 +1495,12 @@ class Base(nn.Module):
 return [ logit.argmax(dim=1) for logit in logits ]
 # perform repetition penalizing
-if "len" not in self.capabilities and prev_list is not None:
+if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
 # to-do: figure out a faster way to handle tolist()
 logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 # (AR) perform length penalizing
-if quant_levels is None and self.causal and prev_list is not None:
+if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
 logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 # perform top_k/top_p filtering of our logits
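The two new guards work because the penalties are no-ops at their default strengths: a repetition penalty of 1.0 and a length penalty of 0.0 leave the logits untouched, so the per-token tolist() and penalty math can be skipped outright. A rough sketch of that reasoning with simplified penalty functions (these are illustrative, not the repo's reptition_penalize/length_penalize):

import torch

def apply_repetition_penalty(logits: torch.Tensor, previous: list[int], factor: float) -> torch.Tensor:
    if factor == 1.0:  # neutral strength: multiplying/dividing by 1.0 changes nothing
        return logits
    logits = logits.clone()
    for token in set(previous):
        score = logits[..., token]
        # push already-seen tokens down, scaling by the penalty factor
        logits[..., token] = torch.where(score < 0, score * factor, score / factor)
    return logits

def apply_length_penalty(logits: torch.Tensor, length: int, factor: float, stop_token: int) -> torch.Tensor:
    if factor == 0.0:  # neutral strength: adding zero changes nothing
        return logits
    logits = logits.clone()
    logits[..., stop_token] += factor * length  # nudge the stop token as the sequence grows
    return logits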