added experimental entropix sampling support

This commit is contained in:
mrq 2024-10-11 21:18:26 -05:00
parent 85d85c1351
commit bef43a0c18
10 changed files with 337 additions and 74 deletions

View File

@ -237,6 +237,8 @@ class ModelExperimentalSettings:
p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
# to-to: just incorporate this as a task instead # to-to: just incorporate this as a task instead
entropix_sampling: bool = False # experimental sampling based on https://github.com/xjdr-alt/entropix, experimental flag because it requires using naive attention for output scores
# I really need to clean this up # I really need to clean this up
@dataclass() @dataclass()
class Model: class Model:

View File

@ -1458,7 +1458,7 @@ def process_artifact_metadata( artifact ):
metadata["similar"] = artifact["metadata"]["similar"] metadata["similar"] = artifact["metadata"]["similar"]
# duration for use of culling / sorting dataset # duration for use of culling / sorting dataset
if "duration" in artifact["metadata"]: if "duration" in artifact["metadata"]:
metadata["duration"] = duration metadata["duration"] = float(artifact["metadata"]["duration"])
# derive duration from sample count / sample rate # derive duration from sample count / sample rate
elif "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]: elif "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]:
metadata["duration"] = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"] metadata["duration"] = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]

View File

@ -184,7 +184,7 @@ def main():
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else []) extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
if not args.random_prompts: if not args.random_prompts or k == "librispeech":
extra_sources += [ reference ] extra_sources += [ reference ]
samples.append(( samples.append((

View File

@ -196,10 +196,13 @@ def process(
for filename in sorted(metadata.keys()): for filename in sorted(metadata.keys()):
inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}') inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
"""
if not inpath.exists(): if not inpath.exists():
missing["audio"].append(str(inpath)) missing["audio"].append(str(inpath))
continue continue
"""
extension = os.path.splitext(filename)[-1][1:] extension = os.path.splitext(filename)[-1][1:]
fname = filename.replace(f'.{extension}', "") fname = filename.replace(f'.{extension}', "")
@ -220,10 +223,19 @@ def process(
jobs.append(( outpath, waveform, sample_rate, text, language )) jobs.append(( outpath, waveform, sample_rate, text, language ))
else: else:
i = 0 i = 0
presliced = not inpath.exists()
for segment in metadata[filename]["segments"]: for segment in metadata[filename]["segments"]:
id = pad(i, 4) id = pad(i, 4)
i = i + 1 i = i + 1
if presliced:
inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
if not inpath.exists():
missing["audio"].append(str(inpath))
continue
outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}').with_suffix(audio_extension) outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}').with_suffix(audio_extension)
text = segment["text"] text = segment["text"]
@ -234,18 +246,19 @@ def process(
if waveform is None: if waveform is None:
waveform, sample_rate = load_audio( inpath ) waveform, sample_rate = load_audio( inpath )
start = int(segment['start'] * sample_rate) start = int((segment['start']-0.05) * sample_rate)
end = int(segment['end'] * sample_rate) end = int((segment['end']+0.5) * sample_rate)
if start < 0: if not presliced:
start = 0 if start < 0:
if end >= waveform.shape[-1]: start = 0
end = waveform.shape[-1] - 1 if end >= waveform.shape[-1]:
end = waveform.shape[-1] - 1
if end - start < 0: if end - start < 0:
continue continue
jobs.append(( outpath, waveform[:, start:end], sample_rate, text, language )) jobs.append(( outpath, waveform if presliced else waveform[:, start:end], sample_rate, text, language ))
# processes audio files one at a time # processes audio files one at a time
if low_memory: if low_memory:
@ -287,6 +300,11 @@ def main():
args.stride_offset = int(args.device) args.stride_offset = int(args.device)
args.device = f'cuda:{args.device}' args.device = f'cuda:{args.device}'
if args.slice == "true":
args.slice = True
elif args.slice == "false":
args.slice = False
process( process(
audio_backend=args.audio_backend, audio_backend=args.audio_backend,
input_audio=args.input_audio, input_audio=args.input_audio,

View File

@ -188,18 +188,13 @@ class AR(Base):
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
) )
if state is not None: output = super().forward(
logits, state = super().forward( inputs=inputs,
inputs=inputs, state=state,
state=state, )
) logits, state = output.logits, output.state
else:
logits = super().forward(
inputs=inputs,
state=state,
)
r = super().sample( sampled = super().sample(
logits=logits, logits=logits,
prev_list=resps_list, prev_list=resps_list,
@ -219,15 +214,13 @@ class AR(Base):
dry_allowed_length=sampling_dry_allowed_length, dry_allowed_length=sampling_dry_allowed_length,
) )
r = sampled[0]
if mirostat is not None: if mirostat is not None:
# r is the state mirostat = sampled.scores
mirostat = r
# extract token from state
r = [ state["token"] for state in mirostat ]
# we do it here because the sampler will already expand our logits list
elif sampling_beam_width > 0: elif sampling_beam_width > 0:
# expand tuple # expand tuple
r, s = r scores = sampled.scores
# first step, expand batch # first step, expand batch
if batch_size == 1: if batch_size == 1:
batch_size = sampling_beam_width batch_size = sampling_beam_width
@ -236,7 +229,7 @@ class AR(Base):
sequence_list = sequence_list * sampling_beam_width sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool() stopped = torch.zeros(batch_size, device=device).bool()
scores = [ scores[i] + score for i, score in enumerate(s) ] scores = [ scores[i] + score for i, score in enumerate(scores) ]
# append tokens # append tokens
for i, ri in enumerate(r): for i, ri in enumerate(r):

View File

@ -64,6 +64,8 @@ class AR_NAR(Base):
sampling_dry_base=1.75, sampling_dry_base=1.75,
sampling_dry_allowed_length=2, sampling_dry_allowed_length=2,
sampling_entropix=None,
disable_tqdm=False, disable_tqdm=False,
use_lora=None, use_lora=None,
): ):
@ -222,11 +224,9 @@ class AR_NAR(Base):
inputs=inputs, inputs=inputs,
quant_levels=quant_levels, quant_levels=quant_levels,
) )
if not isinstance( output, tuple ): logits, state = output.logits, output.state
output = (output, None)
logits, state = output
resps_list = super().sample( sampled = super().sample(
logits=logits, logits=logits,
prev_list=prev_list, prev_list=prev_list,
quant_levels=quant_levels, quant_levels=quant_levels,
@ -242,6 +242,8 @@ class AR_NAR(Base):
#mirostat=mirostat, #mirostat=mirostat,
) )
resps_list = sampled[0]
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list return prev_list
@ -264,6 +266,10 @@ class AR_NAR(Base):
] * batch_size if sampling_mirostat_tau > 0.0 else None ] * batch_size if sampling_mirostat_tau > 0.0 else None
scores = [ 1.0 ] * sampling_beam_width scores = [ 1.0 ] * sampling_beam_width
entropies = []
if sampling_entropix is None:
sampling_entropix = self.config.experimental.entropix_sampling
for i, sequence in enumerate( sequence_list ): for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT # add <bos> to text for STT
@ -296,13 +302,11 @@ class AR_NAR(Base):
output = super().forward( output = super().forward(
inputs=inputs, inputs=inputs,
state=state, state=state,
output_attentions=sampling_entropix,
) )
if not isinstance( output, tuple ): logits, state = output.logits, output.state
output = (output, None)
logits, state = output
r = super().sample( sampled = super().sample(
logits=logits, logits=logits,
prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ], prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
@ -320,17 +324,20 @@ class AR_NAR(Base):
dry_multiplier=sampling_dry_multiplier, dry_multiplier=sampling_dry_multiplier,
dry_base=sampling_dry_base, dry_base=sampling_dry_base,
dry_allowed_length=sampling_dry_allowed_length, dry_allowed_length=sampling_dry_allowed_length,
attentions=output.attentions if sampling_entropix else None,
) )
r = sampled[0]
if sampled.entropy:
entropies.append( sampled.entropy )
if mirostat is not None: if mirostat is not None:
# r is the state mirostat = sampled.scores
mirostat = r
# extract token from state
r = [ state["token"] for state in mirostat ]
# we do it here because the sampler will already expand our logits list
elif sampling_beam_width > 0: elif sampling_beam_width > 0:
# expand tuple # expand tuple
r, s = r scores = sampled.scores
# first step, expand batch # first step, expand batch
if batch_size == 1: if batch_size == 1:
batch_size = sampling_beam_width batch_size = sampling_beam_width
@ -339,7 +346,7 @@ class AR_NAR(Base):
sequence_list = sequence_list * sampling_beam_width sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool() stopped = torch.zeros(batch_size, device=device).bool()
scores = [ scores[i] + score for i, score in enumerate(s) ] scores = [ scores[i] + score for i, score in enumerate(scores) ]
# append tokens # append tokens
for i, ri in enumerate(r): for i, ri in enumerate(r):
@ -354,6 +361,10 @@ class AR_NAR(Base):
if stopped.all().item(): if stopped.all().item():
break break
if entropies:
from ..plot import plot_entropies
plot_entropies( entropies )
# pick the best scoring candidate # pick the best scoring candidate
# desu this is always going to be candidate 0 # desu this is always going to be candidate 0
if sampling_beam_width: if sampling_beam_width:

View File

@ -15,8 +15,9 @@ import torch.nn.functional as F
import random import random
import numpy as np import numpy as np
import re import re
from time import perf_counter
from time import perf_counter
from collections import namedtuple
from typing import Literal, overload, Optional, Tuple from typing import Literal, overload, Optional, Tuple
from functools import partial from functools import partial
from einops import rearrange from einops import rearrange
@ -37,6 +38,9 @@ from ..emb.qnt import encode_as_embedding
# yuck, kind of needed # yuck, kind of needed
from ..data import get_task_symmap from ..data import get_task_symmap
# Lightweight result records, preferred here over plain dicts.
# `Logits` bundles what a forward pass hands back: the logits, the (possibly None)
# cache state, any auxiliary loss, and the attention outputs when they were requested.
# `Sampled` bundles what sampling hands back: the sampled output (index 0), any
# per-candidate scores, and the entropy metrics when they were computed.
Logits = namedtuple("Logits", "logits state aux_loss attentions")
Sampled = namedtuple("Sampled", "out scores entropy")
""" """
from ..utils.pattern import DelayedPatternProvider, VALLEPattern from ..utils.pattern import DelayedPatternProvider, VALLEPattern
""" """
@ -805,11 +809,15 @@ class Base(nn.Module):
inputs, inputs,
mask = None, mask = None,
position_ids = None, position_ids = None,
state = None, state = None,
output_attentions = False,
): ):
x = inputs x = inputs
m = mask.squeeze(-1).int() m = mask.squeeze(-1).int()
aux_loss = None aux_loss = None
attentions = None
# HF transformer derived model # HF transformer derived model
if self.arch_type in ["llama", "mistral", "mixtral"]: if self.arch_type in ["llama", "mistral", "mixtral"]:
@ -819,22 +827,25 @@ class Base(nn.Module):
past_key_values=state, past_key_values=state,
position_ids=position_ids, position_ids=position_ids,
use_cache=not self.training, use_cache=not self.training,
# return_dict=True, output_attentions=output_attentions,
return_dict=True,
) )
if self.n_experts > 1 and self.training: if self.n_experts > 1 and self.training:
kwargs["output_router_logits"] = True kwargs["output_router_logits"] = True
t = self.model(**kwargs) output = self.model(**kwargs)
x = output["last_hidden_state"]
x = t[0]
# to-do: figure out why KV caching doesn't work # to-do: figure out why KV caching doesn't work
#if not self.training: #if not self.training:
if state is not None: if state is not None:
state = t[1] state = output["past_key_values"]
if output_attentions:
attentions = output["attentions"]
if self.n_experts > 1 and self.training: if self.n_experts > 1 and self.training:
router_logits = t[-1] router_logits = output["aux_loss"]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok ) aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
elif self.arch_type == "transformer": elif self.arch_type == "transformer":
# ensures we specify a quant_level for the transformer implementation's AdaLN # ensures we specify a quant_level for the transformer implementation's AdaLN
@ -895,7 +906,7 @@ class Base(nn.Module):
if self.classifier is not None: if self.classifier is not None:
x = self.classifier(x) * mask x = self.classifier(x) * mask
return x, state, aux_loss return Logits(x, state, aux_loss, attentions)
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs( def inputs(
@ -1390,6 +1401,7 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None, quant_levels: int | list[int] | Tensor | None = None,
state: dict | list | None = None, state: dict | list | None = None,
output_attentions = False,
): ):
x_list = self.inputs_to_embeddings( inputs, quant_levels ) x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list) x, m = list_to_tensor(x_list)
@ -1420,32 +1432,36 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs # needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
x, state, aux_loss = self._forward( output = self._forward(
inputs=x, inputs=x,
mask=m, mask=m,
state=state, state=state,
position_ids=position_ids, position_ids=position_ids,
output_attentions = output_attentions,
) )
logits = output.logits
# to-do: piece-wise classification, now that there's a head for text # to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead...... # although again, one single monolithic head would be preferable instead......
if self.classifiers is not None: if self.classifiers is not None:
special_tasks = [ "len", "stt" ] special_tasks = [ "len", "stt" ]
classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ] classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
x = self.classifiers(x, levels = classifier_quant_levels) * m logits = self.classifiers(logits, levels = classifier_quant_levels) * m
# Remove padding # Remove padding
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ] logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
# compute loss if the target is given # compute loss if the target is given
if training: if training:
self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
# include any additional losses (for example: MoE router) # include any additional losses (for example: MoE router)
if aux_loss is not None: if output.aux_loss is not None:
self.loss["aux_loss"] = aux_loss self.loss["aux_loss"] = output.aux_loss
return (logits, state) if state is not None else logits # rewrap, because we're modifying the logits here
return Logits(logits, output.state, output.aux_loss, output.attentions)
def sample( def sample(
self, self,
@ -1470,10 +1486,15 @@ class Base(nn.Module):
dry_multiplier=0.0, dry_multiplier=0.0,
dry_base=1.75, dry_base=1.75,
dry_allowed_length=2, dry_allowed_length=2,
# other
attentions=None,
): ):
if min_temperature < 0: if min_temperature < 0:
min_temperature = temperature min_temperature = temperature
scores = None
entropy = None
# (NAR) return the entire generated response # (NAR) return the entire generated response
# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously) # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
@ -1482,9 +1503,114 @@ class Base(nn.Module):
elif self.causal: elif self.causal:
logits = [ logit[-self.causal_size:] for logit in logits ] logits = [ logit[-self.causal_size:] for logit in logits ]
# this might actually slow things down a bit slightly-er? # calculate entropies
#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] # I would love to shove it in samplers.py but we modify our sampler settings
if attentions is not None:
entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ]
# this might actually slow things down a bit slightly-er?
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
# to-do: not make it hardcoded to bsz=1
metrics = entropy[0]
logit = logits[0]
ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
agreement = metrics["agreement"]
interaction_strength = metrics["interaction_strength"]
# adjust sample settings
cfg = EntropixSamplerConfig()
# Low Entropy, Low Varentropy: "flowing with unspoken intent"
if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
entropy[0]["action"] = 0
temperature *= 0
# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
entropy[0]["action"] = 1
# sample with slightly higher temperature
temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy
# Low Entropy, High Varentropy: "exploring forks in the path"
elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
entropy[0]["action"] = 2
temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength
top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low
# High Entropy, High Varentropy: "resampling in the mist"
elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
entropy[0]["action"] = 3
# Use high temperature and adjusted top_p based on attention metrics
temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy
top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high
# Middle ground: use adaptive sampling
else:
entropy[0]["action"] = 4
log_softmax = torch.nn.functional.log_softmax(logit)
logits_uncertainty = ent + vent
attn_uncertainty = attn_ent + attn_vent
temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
top_p = torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0).item()
top_k = int(torch.clip(
torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)),
min=cfg.top_k_min,
max=cfg.top_k_max
))
min_p = torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5)
def _sample( logits ):
# perform repetition penalizing
if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
# to-do: figure out a faster way to handle tolist()
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
# perform top_k/top_p filtering of our logits
if top_k > 0 or top_p < 1.0:
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
# epsilon float comparison because I don't trust Python
if abs(temperature - min_temperature) >= 0.001:
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
else:
logits = [ logit / temperature for logit in logits ]
# do DRY sampling
if dry_multiplier > 0.0:
logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
return [ Categorical(logits=logit).sample() for logit in logits ]
samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]
def score_sample(sample):
one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
log_prob = torch.sum(log_softmax * one_hot)
confidence_score = (
(1 - ent) * cfg.ada_score_logits_ent +
(1 - attn_ent) * cfg.ada_score_attn_ent +
(1 - vent) * cfg.ada_score_logits_vent +
(1 - attn_vent) * cfg.ada_score_attn_vent +
agreement * cfg.ada_score_agree +
interaction_strength * cfg.ada_score_int
)
return log_prob + confidence_score
sample_scores = [ score_sample(sample) for sample in samples ]
best_sample_idx = torch.argmax(torch.asarray(sample_scores))
res = samples[best_sample_idx]
scores = sample_scores
return Sampled(res, scores, entropy)
temperature = min(1.5, float(temperature))
# (NAR) disable stop token # (NAR) disable stop token
if quant_levels is not None and "ar" in self.capabilities: if quant_levels is not None and "ar" in self.capabilities:
logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ] logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
@ -1494,7 +1620,9 @@ class Base(nn.Module):
# argmax instead # argmax instead
if temperature <= 0.0: if temperature <= 0.0:
return [ logit.argmax(dim=1) for logit in logits ] res = [ logit.argmax(dim=1) for logit in logits ]
scores = None
return Sampled(res, scores, entropy)
# perform repetition penalizing # perform repetition penalizing
if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0: if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
@ -1524,17 +1652,18 @@ class Base(nn.Module):
# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
if mirostat is not None: if mirostat is not None:
# mirostat sampling # mirostat sampling
return [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ] scores = [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]
res = [ state["token"] for state in scores ]
# do beam search (naive implementation) # do beam search (naive implementation)
# picks the top-k across all batches, and re-batches those resultant tokens # picks the top-k across all batches, and re-batches those resultant tokens
# returns the logit scores as well to be P-concatted with the previous scores # returns the logit scores as well to be P-concatted with the previous scores
# to-do: not naively implement beam searching # to-do: not naively implement beam searching
if beam_width > 1: elif beam_width > 1:
candidates = top_k_logits_list( logits, beam_width ) candidates = top_k_logits_list( logits, beam_width )
res = [ torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] res = [ torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
scores = [ logits[batch].flatten()[token] for batch, token in candidates ] scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
return res, scores # basic sampling
else:
res = [ Categorical(logits=logit).sample() for logit in logits ]
# and sample return Sampled(res, scores, entropy)
return [ Categorical(logits=logit).sample() for logit in logits ]

View File

@ -172,16 +172,17 @@ class NAR(Base):
quant_levels=quant_levels, quant_levels=quant_levels,
) )
logits = super().forward( output = super().forward(
inputs=inputs, inputs=inputs,
quant_levels=quant_levels, quant_levels=quant_levels,
) )
logits = output.logits
""" """
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ] resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
""" """
resps_list = super().sample( sampled = super().sample(
logits=logits, logits=logits,
prev_list=prev_list, prev_list=prev_list,
quant_levels=quant_levels, quant_levels=quant_levels,
@ -196,6 +197,7 @@ class NAR(Base):
#beam_width=sampling_beam_width, #beam_width=sampling_beam_width,
#mirostat=mirostat, #mirostat=mirostat,
) )
resps_list = sampled[0]
if n == 0: if n == 0:
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ] prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
@ -225,9 +227,10 @@ class NAR(Base):
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
) )
logits = super().forward( output = super().forward(
inputs=inputs, inputs=inputs,
) )
logits = output.logits
r = [ logit[-1:].argmax(dim=1) for logit in logits ] r = [ logit[-1:].argmax(dim=1) for logit in logits ]
# sanitize # sanitize

View File

@ -93,6 +93,29 @@ def plot(paths, args):
#bbox_to_anchor=(1.04, 0.5), #bbox_to_anchor=(1.04, 0.5),
) )
def plot_entropies( entropies ):
    """Plot the recorded sampling metrics as line series and save them as `metrics.png`.

    Args:
        entropies: a sequence with one entry per recorded step. Each entry is
            indexable, and only its element `[0]` (a dict of metric name -> value)
            is plotted -- presumably the first batch item; confirm against the caller.
            Values exposing `.item()` (e.g. single-element tensors) are converted
            to plain Python scalars first.

    The metric names are taken from the first entry, so every entry must carry
    at least those keys. The image is written to `cfg.rel_path / "metrics.png"`.
    An empty `entropies` is a no-op.
    """
    # nothing was recorded, so there is nothing to plot
    if not entropies:
        return

    # NOTE: scaling the figure width with the sequence length is currently disabled:
    #   fig = plt.figure()
    #   fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )

    def _as_scalar( value ):
        # tensors / numpy scalars expose .item(); plain numbers pass through untouched
        return value.item() if hasattr( value, "item" ) else value

    data = {
        key: [ _as_scalar( e[0][key] ) for e in entropies ]
        for key in entropies[0][0]
    }

    # DataFrame.plot() creates a fresh figure and returns its axes
    ax = pd.DataFrame(data).plot()
    ax.legend(
        #loc="center left",
        fancybox=True,
        shadow=True,
        #bbox_to_anchor=(1.04, 0.5),
    )

    out_path = cfg.rel_path / "metrics.png"
    try:
        ax.figure.savefig(out_path, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; without this,
        # each call would leak one figure
        plt.close(ax.figure)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -5,6 +5,8 @@ import numpy as np
from torch import Tensor, einsum, nn from torch import Tensor, einsum, nn
from dataclasses import asdict, dataclass, field
# Simple filter to modify a token's probability if it shows up in the past # Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once # `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is # `decay` is a factor that will exponentially apply to how far away it is
@ -201,4 +203,86 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2
break break
logits[:, token] -= factor * base ** (length - allowed_length) logits[:, token] -= factor * base ** (length - allowed_length)
return logits return logits
LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E

# Grabbed from https://github.com/xjdr-alt/entropix/blob/main/entropix/sampler.py
# Right now I only care about quantifying these two, I'll figure out how to best apply this to the model
def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
    """Calculate entropy / varentropy of the logits and, optionally, attention statistics.

    The logits are normalised with log-softmax along `dim`. Entropy is
    `-sum(p * ln p) / ln 2` (i.e. in bits) and varentropy is
    `sum(p * (log2 p + entropy)^2)`, both reduced along `dim`.

    Args:
        logits: tensor of unnormalised scores.
        attention_scores: optional tensor of attention scores. It is reduced over
            dims (1, 2, 3), so it needs at least four dims -- presumably
            (batch, heads, query, key); TODO confirm against the caller.
        dim: axis of `logits` holding the distribution.

    Returns:
        A dict of metrics. Without `attention_scores` it holds `logits_entropy`
        and `logits_varentropy` as Python floats. With `attention_scores` those
        two are returned as tensors (mean over all positions), alongside:
          * `attn_entropy`: mean entropy (bits) of softmax(attention_scores) over the last dim
          * `attn_varentropy`: mean of the variance of that entropy along dim 1
          * `agreement`: mean absolute deviation of the attention probabilities
            from their mean over dim 1
          * `interaction_strength`: mean absolute raw score over dims (1, 2, 3);
            this one keeps dim 0 and is not reduced to a single value
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=dim)
    probs = torch.exp(log_probs)
    entropy = -torch.sum(probs * log_probs, dim=dim) / LN_2 # Convert to base-2
    # put the reduced axis back at `dim` (not blindly at the end) so the entropy
    # broadcasts against log_probs for any choice of `dim`
    varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(dim))**2, dim=dim)

    if attention_scores is None:
        return {
            "logits_entropy": torch.mean(entropy).item(),
            "logits_varentropy": torch.mean(varentropy).item(),
        }

    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
    # clip before the log so exact zeros do not produce -inf * 0 = nan
    attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clip(attention_probs, 1e-10, 1.0)), dim=-1)
    attn_varentropy = torch.var(attn_entropy, dim=1)

    mean_attention = torch.mean(attention_probs, dim=1)
    agreement = torch.mean(torch.abs(attention_probs - mean_attention[:, None, :]), dim=(1, 2))

    interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))

    return {
        "logits_entropy": torch.mean(entropy),
        "logits_varentropy": torch.mean(varentropy),
        "attn_entropy": torch.mean(attn_entropy),
        "attn_varentropy": torch.mean(attn_varentropy),
        "agreement": torch.mean(agreement),
        # reuse the value computed above instead of reducing the scores a second time
        "interaction_strength": interaction_strength,
    }
# to-do: play around with these values
@dataclass
class EntropixSamplerConfig:
    """Hyperparameters for the experimental entropix-style sampler.

    All values are plain defaults meant to be tuned (to-do: play around with them).
    Prefixes group the knobs: `helv` = high entropy / low varentropy,
    `lehv` = low entropy / high varentropy, `hehv` = high entropy / high varentropy,
    `ada` = the adaptive ("middle ground") sampling case.
    """

    # baseline sampler settings
    temp: float = 0.999
    top_p: float = 0.90
    top_k: int = 32
    min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth

    # entropy / varentropy thresholds
    low_ent_thresh: float = 0.1
    low_vent_thresh: float = 0.1
    med_ent_thresh: float = 3.0
    high_ent_thresh: float = 5.0
    high_vent_thresh: float = 5.0

    # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible
    # per-case offsets and coefficients
    helv_attn_ent_offset: float = 1.3
    helv_attn_ent_coef: float = 0.2

    lehv_interaction_strength_offset: float = 1.2
    lehv_interaction_strength_coef: float = 0.3

    hehv_attn_ent_coef: float = 0.2
    hehv_attn_vent_offset: float = 2.0
    hehv_attn_vent_coef: float = 0.5

    # TODO not convinced this should (note left unfinished)
    n_adaptive_samples: int = 5

    # adaptive sampling: scaling of temperature / top_p / top_k / min_p
    ada_temp_logits: float = 0.3
    ada_temp_attn: float = 0.2
    ada_temp_agree: float = 0.2
    ada_top_p: float = 0.1
    ada_top_k_int: float = 0.3
    ada_top_k_agree: float = 0.2
    ada_min_p: float = 0.5

    # adaptive sampling: weights of the confidence score terms
    ada_score_logits_ent: float = 0.1
    ada_score_attn_ent: float = 0.2
    ada_score_logits_vent: float = 0.3
    ada_score_attn_vent: float = 0.4
    ada_score_agree: float = 0.5
    ada_score_int: float = 0.6

    # extra stuff: clamp range for the adapted top_k
    top_k_min: int = 32
    top_k_max: int = 128