added picking final candidate based on sum of score instead of first candidate (this changes nothing).

This commit is contained in:
mrq 2023-09-13 13:19:11 -05:00
parent 23a5fdd645
commit 4aef798135
4 changed files with 85 additions and 36 deletions

View File

@ -121,12 +121,14 @@ class AR(Base):
stopped = torch.zeros(batch_size, device=device).bool()
state = {} if cfg.inference.recurrent_forward else None
sampling_beam_width_use_logs = True
scores = [ 1.0 ] * sampling_beam_width
if self.interleave:
max_steps *= self.n_prom_levels
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
logits = super().forward(
text_list=text_list,
@ -149,12 +151,23 @@ class AR(Base):
beam_width=sampling_beam_width,
)
# first step, expand batch
# we do it here because the sampler will already expand our logits list
if sampling_beam_width > 0 and batch_size == 1:
if sampling_beam_width > 0:
# expand tuple
r, s = r
# first step, expand batch
if batch_size == 1:
batch_size *= sampling_beam_width
text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width
resps_list = resps_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()
# update scores
if sampling_beam_width_use_logs:
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
else:
scores = [ scores[i] * score for i, score in enumerate(s) ]
# append tokens
for i, ri in enumerate(r):
@ -168,6 +181,16 @@ class AR(Base):
if stopped.all().item():
break
# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width and len(scores) > 0:
best_idx, best_score = (0, 0)
for idx, score in enumerate(scores):
if best_score > score:
best_idx, best_score = idx, score
sequence_list = [sequence_list[best_idx]]
res = [self._prune(r) for r in resps_list]
if self.interleave:
res = [self._deinterleave(r) for r in res]

View File

@ -5,6 +5,7 @@ import torch
from torch.nn.utils.rnn import pad_sequence
import random
import math
from einops import rearrange
from torch import Tensor
from tqdm import trange
@ -151,12 +152,14 @@ class AR_NAR(Base):
state = {} if cfg.inference.recurrent_forward else None
sampling_beam_width_use_logs = True
scores = [ 1.0 ] * sampling_beam_width
if self.interleave:
max_steps *= self.n_prom_levels
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
resps_list = self._unsqueeze_list(sequence_list)
logits = super().forward(
text_list=text_list,
@ -179,15 +182,24 @@ class AR_NAR(Base):
beam_width=sampling_beam_width,
)
# first step, expand batch
# we do it here because the sampler will already expand our logits list
if sampling_beam_width > 0 and batch_size == 1:
if sampling_beam_width > 0:
# expand tuple
r, s = r
# first step, expand batch
if batch_size == 1:
batch_size *= sampling_beam_width
text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()
# update scores
if sampling_beam_width_use_logs:
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
else:
scores = [ scores[i] * score for i, score in enumerate(s) ]
# append tokens
for i, ri in enumerate(r):
if self.stop_token in ri:
@ -199,9 +211,15 @@ class AR_NAR(Base):
if stopped.all().item():
break
# pick the first candidate
if sampling_beam_width:
sequence_list = sequence_list[:1]
# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width and len(scores) > 0:
best_idx, best_score = (0, 0)
for idx, score in enumerate(scores):
if best_score > score:
best_idx, best_score = idx, score
sequence_list = [sequence_list[best_idx]]
return [self._prune(r) for r in sequence_list]

View File

@ -119,6 +119,22 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
return logits
# picks the top K tokens amongst a batch of logits
# logits: [Tensor] list of logits
# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
def top_k_logits_list( logits_list, k ):
# ( batch, tokens ) => ( batch x tokens )
logits = torch.cat( logits_list )
candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
for i, index in enumerate(candidates):
t = []
N = np.prod(logits.size())
for n in logits.size():
N //= n
t.append(index // N)
index %= N
candidates[i] = tuple(t)
return candidates
# automagically parses a batch-list and returns it as a list
class Embedding(nn.Embedding):
@ -128,7 +144,7 @@ class Embedding(nn.Embedding):
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
class MultiEmbedding(nn.Embedding):
class MultiEmbedding(nn.Module):
"""
This embedding sums embeddings on different levels.
"""
@ -468,21 +484,13 @@ class Base(nn.Module):
# do beam search (naive implementation)
# picks the top-k across all batches, and re-batches those resultant tokens
# this doesn't do any other mumbo with previous logits
# returns the logit scores as well to be P-concatted with the previous scores
# to-do: not naively implement beam searching
if beam_width > 1:
# ( batch, tokens ) => ( batch x tokens )
flattened = torch.cat( logits )
candidates = list(torch.topk(flattened.flatten(), beam_width).indices.tolist()) # perform top-k across all logits
for i, index in enumerate(candidates):
t = []
N = np.prod(flattened.size())
for n in flattened.size():
N //= n
t.append(index // N)
index %= N
candidates[i] = tuple(t)
return [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] #, [ logits[batch] for batch, token in candidates ]
candidates = top_k_logits_list( logits, beam_width )
res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
return res, scores
# and sample
# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax

View File

@ -152,10 +152,10 @@ def run_eval(engines, disabled_engines, eval_name, dl):
stats = {k: sum(v) / len(v) for k, v in stats.items()}
engines_stats.update({ f'{name}.{eval_name}': stats })
iteration = engines.global_step
engines_stats['it'] = iteration
engines_stats = {
f'{name}.{eval_name}': stats,
"it": engines.global_step,
}
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")