added picking final candidate based on sum of score instead of first candidate (this changes nothing).
This commit is contained in:
parent
23a5fdd645
commit
4aef798135
|
@ -121,12 +121,14 @@ class AR(Base):
|
||||||
stopped = torch.zeros(batch_size, device=device).bool()
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
state = {} if cfg.inference.recurrent_forward else None
|
state = {} if cfg.inference.recurrent_forward else None
|
||||||
|
sampling_beam_width_use_logs = True
|
||||||
|
scores = [ 1.0 ] * sampling_beam_width
|
||||||
|
|
||||||
if self.interleave:
|
if self.interleave:
|
||||||
max_steps *= self.n_prom_levels
|
max_steps *= self.n_prom_levels
|
||||||
|
|
||||||
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
|
||||||
# get next in sequence
|
# get next in sequence
|
||||||
|
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
||||||
|
|
||||||
logits = super().forward(
|
logits = super().forward(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
|
@ -149,12 +151,23 @@ class AR(Base):
|
||||||
beam_width=sampling_beam_width,
|
beam_width=sampling_beam_width,
|
||||||
)
|
)
|
||||||
|
|
||||||
# first step, expand batch
|
|
||||||
# we do it here because the sampler will already expand our logits list
|
# we do it here because the sampler will already expand our logits list
|
||||||
if sampling_beam_width > 0 and batch_size == 1:
|
if sampling_beam_width > 0:
|
||||||
|
# expand tuple
|
||||||
|
r, s = r
|
||||||
|
# first step, expand batch
|
||||||
|
if batch_size == 1:
|
||||||
|
batch_size *= sampling_beam_width
|
||||||
text_list = text_list * sampling_beam_width
|
text_list = text_list * sampling_beam_width
|
||||||
proms_list = proms_list * sampling_beam_width
|
proms_list = proms_list * sampling_beam_width
|
||||||
resps_list = resps_list * sampling_beam_width
|
sequence_list = sequence_list * sampling_beam_width
|
||||||
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
|
# update scores
|
||||||
|
if sampling_beam_width_use_logs:
|
||||||
|
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
|
||||||
|
else:
|
||||||
|
scores = [ scores[i] * score for i, score in enumerate(s) ]
|
||||||
|
|
||||||
# append tokens
|
# append tokens
|
||||||
for i, ri in enumerate(r):
|
for i, ri in enumerate(r):
|
||||||
|
@ -168,6 +181,16 @@ class AR(Base):
|
||||||
if stopped.all().item():
|
if stopped.all().item():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# pick the best scoring candidate
|
||||||
|
# desu this is always going to be candidate 0
|
||||||
|
if sampling_beam_width and len(scores) > 0:
|
||||||
|
best_idx, best_score = (0, 0)
|
||||||
|
for idx, score in enumerate(scores):
|
||||||
|
if best_score > score:
|
||||||
|
best_idx, best_score = idx, score
|
||||||
|
|
||||||
|
sequence_list = [sequence_list[best_idx]]
|
||||||
|
|
||||||
res = [self._prune(r) for r in resps_list]
|
res = [self._prune(r) for r in resps_list]
|
||||||
if self.interleave:
|
if self.interleave:
|
||||||
res = [self._deinterleave(r) for r in res]
|
res = [self._deinterleave(r) for r in res]
|
||||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import math
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
@ -151,12 +152,14 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
state = {} if cfg.inference.recurrent_forward else None
|
state = {} if cfg.inference.recurrent_forward else None
|
||||||
|
|
||||||
|
sampling_beam_width_use_logs = True
|
||||||
|
scores = [ 1.0 ] * sampling_beam_width
|
||||||
|
|
||||||
if self.interleave:
|
if self.interleave:
|
||||||
max_steps *= self.n_prom_levels
|
max_steps *= self.n_prom_levels
|
||||||
|
|
||||||
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
|
||||||
# get next in sequence
|
# get next in sequence
|
||||||
|
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
||||||
resps_list = self._unsqueeze_list(sequence_list)
|
resps_list = self._unsqueeze_list(sequence_list)
|
||||||
logits = super().forward(
|
logits = super().forward(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
|
@ -179,15 +182,24 @@ class AR_NAR(Base):
|
||||||
beam_width=sampling_beam_width,
|
beam_width=sampling_beam_width,
|
||||||
)
|
)
|
||||||
|
|
||||||
# first step, expand batch
|
|
||||||
# we do it here because the sampler will already expand our logits list
|
# we do it here because the sampler will already expand our logits list
|
||||||
if sampling_beam_width > 0 and batch_size == 1:
|
if sampling_beam_width > 0:
|
||||||
|
# expand tuple
|
||||||
|
r, s = r
|
||||||
|
# first step, expand batch
|
||||||
|
if batch_size == 1:
|
||||||
batch_size *= sampling_beam_width
|
batch_size *= sampling_beam_width
|
||||||
text_list = text_list * sampling_beam_width
|
text_list = text_list * sampling_beam_width
|
||||||
proms_list = proms_list * sampling_beam_width
|
proms_list = proms_list * sampling_beam_width
|
||||||
sequence_list = sequence_list * sampling_beam_width
|
sequence_list = sequence_list * sampling_beam_width
|
||||||
stopped = torch.zeros(batch_size, device=device).bool()
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
|
# update scores
|
||||||
|
if sampling_beam_width_use_logs:
|
||||||
|
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
|
||||||
|
else:
|
||||||
|
scores = [ scores[i] * score for i, score in enumerate(s) ]
|
||||||
|
|
||||||
# append tokens
|
# append tokens
|
||||||
for i, ri in enumerate(r):
|
for i, ri in enumerate(r):
|
||||||
if self.stop_token in ri:
|
if self.stop_token in ri:
|
||||||
|
@ -199,9 +211,15 @@ class AR_NAR(Base):
|
||||||
if stopped.all().item():
|
if stopped.all().item():
|
||||||
break
|
break
|
||||||
|
|
||||||
# pick the first candidate
|
# pick the best scoring candidate
|
||||||
if sampling_beam_width:
|
# desu this is always going to be candidate 0
|
||||||
sequence_list = sequence_list[:1]
|
if sampling_beam_width and len(scores) > 0:
|
||||||
|
best_idx, best_score = (0, 0)
|
||||||
|
for idx, score in enumerate(scores):
|
||||||
|
if best_score > score:
|
||||||
|
best_idx, best_score = idx, score
|
||||||
|
|
||||||
|
sequence_list = [sequence_list[best_idx]]
|
||||||
|
|
||||||
return [self._prune(r) for r in sequence_list]
|
return [self._prune(r) for r in sequence_list]
|
||||||
|
|
||||||
|
|
|
@ -119,6 +119,22 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
# picks the top K tokens amongst a batch of logits
|
||||||
|
# logits: [Tensor] list of logits
|
||||||
|
# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
|
||||||
|
def top_k_logits_list( logits_list, k ):
|
||||||
|
# ( batch, tokens ) => ( batch x tokens )
|
||||||
|
logits = torch.cat( logits_list )
|
||||||
|
candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits
|
||||||
|
for i, index in enumerate(candidates):
|
||||||
|
t = []
|
||||||
|
N = np.prod(logits.size())
|
||||||
|
for n in logits.size():
|
||||||
|
N //= n
|
||||||
|
t.append(index // N)
|
||||||
|
index %= N
|
||||||
|
candidates[i] = tuple(t)
|
||||||
|
return candidates
|
||||||
|
|
||||||
# automagically parses a batch-list and returns it as a list
|
# automagically parses a batch-list and returns it as a list
|
||||||
class Embedding(nn.Embedding):
|
class Embedding(nn.Embedding):
|
||||||
|
@ -128,7 +144,7 @@ class Embedding(nn.Embedding):
|
||||||
|
|
||||||
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
||||||
|
|
||||||
class MultiEmbedding(nn.Embedding):
|
class MultiEmbedding(nn.Module):
|
||||||
"""
|
"""
|
||||||
This embedding sums embeddings on different levels.
|
This embedding sums embeddings on different levels.
|
||||||
"""
|
"""
|
||||||
|
@ -468,21 +484,13 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# do beam search (naive implementation)
|
# do beam search (naive implementation)
|
||||||
# picks the top-k across all batches, and re-batches those resultant tokens
|
# picks the top-k across all batches, and re-batches those resultant tokens
|
||||||
# this doesn't do any other mumbo with previous logits
|
# returns the logit scores as well to be P-concatted with the previous scores
|
||||||
# to-do: not naively implement beam searching
|
# to-do: not naively implement beam searching
|
||||||
if beam_width > 1:
|
if beam_width > 1:
|
||||||
# ( batch, tokens ) => ( batch x tokens )
|
candidates = top_k_logits_list( logits, beam_width )
|
||||||
flattened = torch.cat( logits )
|
res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
|
||||||
candidates = list(torch.topk(flattened.flatten(), beam_width).indices.tolist()) # perform top-k across all logits
|
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
|
||||||
for i, index in enumerate(candidates):
|
return res, scores
|
||||||
t = []
|
|
||||||
N = np.prod(flattened.size())
|
|
||||||
for n in flattened.size():
|
|
||||||
N //= n
|
|
||||||
t.append(index // N)
|
|
||||||
index %= N
|
|
||||||
candidates[i] = tuple(t)
|
|
||||||
return [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] #, [ logits[batch] for batch, token in candidates ]
|
|
||||||
|
|
||||||
# and sample
|
# and sample
|
||||||
# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
|
# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
|
||||||
|
|
|
@ -152,10 +152,10 @@ def run_eval(engines, disabled_engines, eval_name, dl):
|
||||||
|
|
||||||
|
|
||||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||||
engines_stats.update({ f'{name}.{eval_name}': stats })
|
engines_stats = {
|
||||||
|
f'{name}.{eval_name}': stats,
|
||||||
iteration = engines.global_step
|
"it": engines.global_step,
|
||||||
engines_stats['it'] = iteration
|
}
|
||||||
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
||||||
|
|
||||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user