added picking final candidate based on sum of score instead of first candidate (this changes nothing).

This commit is contained in:
mrq 2023-09-13 13:19:11 -05:00
parent 23a5fdd645
commit 4aef798135
4 changed files with 85 additions and 36 deletions

View File

@ -121,12 +121,14 @@ class AR(Base):
stopped = torch.zeros(batch_size, device=device).bool() stopped = torch.zeros(batch_size, device=device).bool()
state = {} if cfg.inference.recurrent_forward else None state = {} if cfg.inference.recurrent_forward else None
sampling_beam_width_use_logs = True
scores = [ 1.0 ] * sampling_beam_width
if self.interleave: if self.interleave:
max_steps *= self.n_prom_levels max_steps *= self.n_prom_levels
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
# get next in sequence # get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
logits = super().forward( logits = super().forward(
text_list=text_list, text_list=text_list,
@ -149,12 +151,23 @@ class AR(Base):
beam_width=sampling_beam_width, beam_width=sampling_beam_width,
) )
# first step, expand batch
# we do it here because the sampler will already expand our logits list # we do it here because the sampler will already expand our logits list
if sampling_beam_width > 0 and batch_size == 1: if sampling_beam_width > 0:
# expand tuple
r, s = r
# first step, expand batch
if batch_size == 1:
batch_size *= sampling_beam_width
text_list = text_list * sampling_beam_width text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width proms_list = proms_list * sampling_beam_width
resps_list = resps_list * sampling_beam_width sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()
# update scores
if sampling_beam_width_use_logs:
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
else:
scores = [ scores[i] * score for i, score in enumerate(s) ]
# append tokens # append tokens
for i, ri in enumerate(r): for i, ri in enumerate(r):
@ -168,6 +181,16 @@ class AR(Base):
if stopped.all().item(): if stopped.all().item():
break break
# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width and len(scores) > 0:
best_idx, best_score = (0, 0)
for idx, score in enumerate(scores):
if best_score > score:
best_idx, best_score = idx, score
sequence_list = [sequence_list[best_idx]]
res = [self._prune(r) for r in resps_list] res = [self._prune(r) for r in resps_list]
if self.interleave: if self.interleave:
res = [self._deinterleave(r) for r in res] res = [self._deinterleave(r) for r in res]

View File

@ -5,6 +5,7 @@ import torch
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
import random import random
import math
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
from tqdm import trange from tqdm import trange
@ -151,12 +152,14 @@ class AR_NAR(Base):
state = {} if cfg.inference.recurrent_forward else None state = {} if cfg.inference.recurrent_forward else None
sampling_beam_width_use_logs = True
scores = [ 1.0 ] * sampling_beam_width
if self.interleave: if self.interleave:
max_steps *= self.n_prom_levels max_steps *= self.n_prom_levels
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
# get next in sequence # get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
resps_list = self._unsqueeze_list(sequence_list) resps_list = self._unsqueeze_list(sequence_list)
logits = super().forward( logits = super().forward(
text_list=text_list, text_list=text_list,
@ -179,15 +182,24 @@ class AR_NAR(Base):
beam_width=sampling_beam_width, beam_width=sampling_beam_width,
) )
# first step, expand batch
# we do it here because the sampler will already expand our logits list # we do it here because the sampler will already expand our logits list
if sampling_beam_width > 0 and batch_size == 1: if sampling_beam_width > 0:
# expand tuple
r, s = r
# first step, expand batch
if batch_size == 1:
batch_size *= sampling_beam_width batch_size *= sampling_beam_width
text_list = text_list * sampling_beam_width text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width proms_list = proms_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool() stopped = torch.zeros(batch_size, device=device).bool()
# update scores
if sampling_beam_width_use_logs:
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
else:
scores = [ scores[i] * score for i, score in enumerate(s) ]
# append tokens # append tokens
for i, ri in enumerate(r): for i, ri in enumerate(r):
if self.stop_token in ri: if self.stop_token in ri:
@ -199,9 +211,15 @@ class AR_NAR(Base):
if stopped.all().item(): if stopped.all().item():
break break
# pick the first candidate # pick the best scoring candidate
if sampling_beam_width: # desu this is always going to be candidate 0
sequence_list = sequence_list[:1] if sampling_beam_width and len(scores) > 0:
best_idx, best_score = (0, 0)
for idx, score in enumerate(scores):
if best_score > score:
best_idx, best_score = idx, score
sequence_list = [sequence_list[best_idx]]
return [self._prune(r) for r in sequence_list] return [self._prune(r) for r in sequence_list]

View File

@ -119,6 +119,22 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
return logits return logits
# picks the top K tokens amongst a batch of logits
# logits: [Tensor] list of logits
# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
def top_k_logits_list( logits_list, k ):
# ( batch, tokens ) => ( batch x tokens )
logits = torch.cat( logits_list )
candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
for i, index in enumerate(candidates):
t = []
N = np.prod(logits.size())
for n in logits.size():
N //= n
t.append(index // N)
index %= N
candidates[i] = tuple(t)
return candidates
# automagically parses a batch-list and returns it as a list # automagically parses a batch-list and returns it as a list
class Embedding(nn.Embedding): class Embedding(nn.Embedding):
@ -128,7 +144,7 @@ class Embedding(nn.Embedding):
return super().forward(torch.cat(x_list)).split([*map(len, x_list)]) return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
class MultiEmbedding(nn.Embedding): class MultiEmbedding(nn.Module):
""" """
This embedding sums embeddings on different levels. This embedding sums embeddings on different levels.
""" """
@ -468,21 +484,13 @@ class Base(nn.Module):
# do beam search (naive implementation) # do beam search (naive implementation)
# picks the top-k across all batches, and re-batches those resultant tokens # picks the top-k across all batches, and re-batches those resultant tokens
# this doesn't do any other mumbo with previous logits # returns the logit scores as well to be P-concatted with the previous scores
# to-do: not naively implement beam searching # to-do: not naively implement beam searching
if beam_width > 1: if beam_width > 1:
# ( batch, tokens ) => ( batch x tokens ) candidates = top_k_logits_list( logits, beam_width )
flattened = torch.cat( logits ) res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
candidates = list(torch.topk(flattened.flatten(), beam_width).indices.tolist()) # perform top-k across all logits scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
for i, index in enumerate(candidates): return res, scores
t = []
N = np.prod(flattened.size())
for n in flattened.size():
N //= n
t.append(index // N)
index %= N
candidates[i] = tuple(t)
return [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] #, [ logits[batch] for batch, token in candidates ]
# and sample # and sample
# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax # the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax

View File

@ -152,10 +152,10 @@ def run_eval(engines, disabled_engines, eval_name, dl):
stats = {k: sum(v) / len(v) for k, v in stats.items()} stats = {k: sum(v) / len(v) for k, v in stats.items()}
engines_stats.update({ f'{name}.{eval_name}': stats }) engines_stats = {
f'{name}.{eval_name}': stats,
iteration = engines.global_step "it": engines.global_step,
engines_stats['it'] = iteration }
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")