added picking final candidate based on sum of score instead of first candidate (this changes nothing).
This commit is contained in:
parent
23a5fdd645
commit
4aef798135
|
@ -121,12 +121,14 @@ class AR(Base):
|
|||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
state = {} if cfg.inference.recurrent_forward else None
|
||||
sampling_beam_width_use_logs = True
|
||||
scores = [ 1.0 ] * sampling_beam_width
|
||||
|
||||
if self.interleave:
|
||||
max_steps *= self.n_prom_levels
|
||||
|
||||
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
||||
# get next in sequence
|
||||
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
||||
|
||||
logits = super().forward(
|
||||
text_list=text_list,
|
||||
|
@ -149,12 +151,23 @@ class AR(Base):
|
|||
beam_width=sampling_beam_width,
|
||||
)
|
||||
|
||||
# first step, expand batch
|
||||
# we do it here because the sampler will already expand our logits list
|
||||
if sampling_beam_width > 0 and batch_size == 1:
|
||||
if sampling_beam_width > 0:
|
||||
# expand tuple
|
||||
r, s = r
|
||||
# first step, expand batch
|
||||
if batch_size == 1:
|
||||
batch_size *= sampling_beam_width
|
||||
text_list = text_list * sampling_beam_width
|
||||
proms_list = proms_list * sampling_beam_width
|
||||
resps_list = resps_list * sampling_beam_width
|
||||
sequence_list = sequence_list * sampling_beam_width
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
# update scores
|
||||
if sampling_beam_width_use_logs:
|
||||
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
|
||||
else:
|
||||
scores = [ scores[i] * score for i, score in enumerate(s) ]
|
||||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
|
@ -168,6 +181,16 @@ class AR(Base):
|
|||
if stopped.all().item():
|
||||
break
|
||||
|
||||
# pick the best scoring candidate
|
||||
# desu this is always going to be candidate 0
|
||||
if sampling_beam_width and len(scores) > 0:
|
||||
best_idx, best_score = (0, 0)
|
||||
for idx, score in enumerate(scores):
|
||||
if best_score > score:
|
||||
best_idx, best_score = idx, score
|
||||
|
||||
sequence_list = [sequence_list[best_idx]]
|
||||
|
||||
res = [self._prune(r) for r in resps_list]
|
||||
if self.interleave:
|
||||
res = [self._deinterleave(r) for r in res]
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
|||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
import random
|
||||
import math
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from tqdm import trange
|
||||
|
@ -151,12 +152,14 @@ class AR_NAR(Base):
|
|||
|
||||
state = {} if cfg.inference.recurrent_forward else None
|
||||
|
||||
sampling_beam_width_use_logs = True
|
||||
scores = [ 1.0 ] * sampling_beam_width
|
||||
|
||||
if self.interleave:
|
||||
max_steps *= self.n_prom_levels
|
||||
|
||||
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
||||
# get next in sequence
|
||||
|
||||
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
||||
resps_list = self._unsqueeze_list(sequence_list)
|
||||
logits = super().forward(
|
||||
text_list=text_list,
|
||||
|
@ -179,15 +182,24 @@ class AR_NAR(Base):
|
|||
beam_width=sampling_beam_width,
|
||||
)
|
||||
|
||||
# first step, expand batch
|
||||
# we do it here because the sampler will already expand our logits list
|
||||
if sampling_beam_width > 0 and batch_size == 1:
|
||||
if sampling_beam_width > 0:
|
||||
# expand tuple
|
||||
r, s = r
|
||||
# first step, expand batch
|
||||
if batch_size == 1:
|
||||
batch_size *= sampling_beam_width
|
||||
text_list = text_list * sampling_beam_width
|
||||
proms_list = proms_list * sampling_beam_width
|
||||
sequence_list = sequence_list * sampling_beam_width
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
# update scores
|
||||
if sampling_beam_width_use_logs:
|
||||
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
|
||||
else:
|
||||
scores = [ scores[i] * score for i, score in enumerate(s) ]
|
||||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
if self.stop_token in ri:
|
||||
|
@ -199,9 +211,15 @@ class AR_NAR(Base):
|
|||
if stopped.all().item():
|
||||
break
|
||||
|
||||
# pick the first candidate
|
||||
if sampling_beam_width:
|
||||
sequence_list = sequence_list[:1]
|
||||
# pick the best scoring candidate
|
||||
# desu this is always going to be candidate 0
|
||||
if sampling_beam_width and len(scores) > 0:
|
||||
best_idx, best_score = (0, 0)
|
||||
for idx, score in enumerate(scores):
|
||||
if best_score > score:
|
||||
best_idx, best_score = idx, score
|
||||
|
||||
sequence_list = [sequence_list[best_idx]]
|
||||
|
||||
return [self._prune(r) for r in sequence_list]
|
||||
|
||||
|
|
|
@ -119,6 +119,22 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
|
|||
|
||||
return logits
|
||||
|
||||
# picks the top K tokens amongst a batch of logits
|
||||
# logits: [Tensor] list of logits
|
||||
# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
|
||||
def top_k_logits_list( logits_list, k ):
|
||||
# ( batch, tokens ) => ( batch x tokens )
|
||||
logits = torch.cat( logits_list )
|
||||
candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
|
||||
for i, index in enumerate(candidates):
|
||||
t = []
|
||||
N = np.prod(logits.size())
|
||||
for n in logits.size():
|
||||
N //= n
|
||||
t.append(index // N)
|
||||
index %= N
|
||||
candidates[i] = tuple(t)
|
||||
return candidates
|
||||
|
||||
# automagically parses a batch-list and returns it as a list
|
||||
class Embedding(nn.Embedding):
|
||||
|
@ -128,7 +144,7 @@ class Embedding(nn.Embedding):
|
|||
|
||||
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
||||
|
||||
class MultiEmbedding(nn.Embedding):
|
||||
class MultiEmbedding(nn.Module):
|
||||
"""
|
||||
This embedding sums embeddings on different levels.
|
||||
"""
|
||||
|
@ -468,21 +484,13 @@ class Base(nn.Module):
|
|||
|
||||
# do beam search (naive implementation)
|
||||
# picks the top-k across all batches, and re-batches those resultant tokens
|
||||
# this doesn't do any other mumbo with previous logits
|
||||
# returns the logit scores as well to be P-concatted with the previous scores
|
||||
# to-do: not naively implement beam searching
|
||||
if beam_width > 1:
|
||||
# ( batch, tokens ) => ( batch x tokens )
|
||||
flattened = torch.cat( logits )
|
||||
candidates = list(torch.topk(flattened.flatten(), beam_width).indices.tolist()) # perform top-k across all logits
|
||||
for i, index in enumerate(candidates):
|
||||
t = []
|
||||
N = np.prod(flattened.size())
|
||||
for n in flattened.size():
|
||||
N //= n
|
||||
t.append(index // N)
|
||||
index %= N
|
||||
candidates[i] = tuple(t)
|
||||
return [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] #, [ logits[batch] for batch, token in candidates ]
|
||||
candidates = top_k_logits_list( logits, beam_width )
|
||||
res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
|
||||
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
|
||||
return res, scores
|
||||
|
||||
# and sample
|
||||
# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
|
||||
|
|
|
@ -152,10 +152,10 @@ def run_eval(engines, disabled_engines, eval_name, dl):
|
|||
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
engines_stats.update({ f'{name}.{eval_name}': stats })
|
||||
|
||||
iteration = engines.global_step
|
||||
engines_stats['it'] = iteration
|
||||
engines_stats = {
|
||||
f'{name}.{eval_name}': stats,
|
||||
"it": engines.global_step,
|
||||
}
|
||||
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
||||
|
||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
|
|
Loading…
Reference in New Issue
Block a user