added experimental entropix sampling support
commit bef43a0c18
parent 85d85c1351

@@ -237,6 +237,8 @@ class ModelExperimentalSettings:
 p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
 # to-to: just incorporate this as a task instead

+entropix_sampling: bool = False # experimental sampling based on https://github.com/xjdr-alt/entropix, experimental flag because it requires using naive attention for output scores
+
 # I really need to clean this up
 @dataclass()
 class Model:

@@ -1458,7 +1458,7 @@ def process_artifact_metadata( artifact ):
 metadata["similar"] = artifact["metadata"]["similar"]
 # duration for use of culling / sorting dataset
 if "duration" in artifact["metadata"]:
-    metadata["duration"] = duration
+    metadata["duration"] = float(artifact["metadata"]["duration"])
 # derive duration from sample count / sample rate
 elif "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]:
 metadata["duration"] = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]

@@ -184,7 +184,7 @@ def main():
 extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])

-if not args.random_prompts:
+if not args.random_prompts or k == "librispeech":
     extra_sources += [ reference ]

 samples.append((

@@ -196,10 +196,13 @@ def process(
 for filename in sorted(metadata.keys()):
 inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')

+"""
 if not inpath.exists():
     missing["audio"].append(str(inpath))
     continue
+"""

 extension = os.path.splitext(filename)[-1][1:]
 fname = filename.replace(f'.{extension}', "")

@@ -220,10 +223,19 @@ def process(
 jobs.append(( outpath, waveform, sample_rate, text, language ))
 else:
 i = 0
+presliced = not inpath.exists()

 for segment in metadata[filename]["segments"]:
 id = pad(i, 4)
 i = i + 1

+if presliced:
+    inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
+
+if not inpath.exists():
+    missing["audio"].append(str(inpath))
+    continue
+
 outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}').with_suffix(audio_extension)
 text = segment["text"]

@@ -234,18 +246,19 @@ def process(
 if waveform is None:
     waveform, sample_rate = load_audio( inpath )

-start = int(segment['start'] * sample_rate)
-end = int(segment['end'] * sample_rate)
+start = int((segment['start']-0.05) * sample_rate)
+end = int((segment['end']+0.5) * sample_rate)

-if start < 0:
-    start = 0
-if end >= waveform.shape[-1]:
-    end = waveform.shape[-1] - 1
+if not presliced:
+    if start < 0:
+        start = 0
+    if end >= waveform.shape[-1]:
+        end = waveform.shape[-1] - 1

 if end - start < 0:
     continue

-jobs.append(( outpath, waveform[:, start:end], sample_rate, text, language ))
+jobs.append(( outpath, waveform if presliced else waveform[:, start:end], sample_rate, text, language ))

 # processes audio files one at a time
 if low_memory:

@@ -287,6 +300,11 @@ def main():
 args.stride_offset = int(args.device)
 args.device = f'cuda:{args.device}'

+if args.slice == "true":
+    args.slice = True
+elif args.slice == "false":
+    args.slice = False
+
 process(
 audio_backend=args.audio_backend,
 input_audio=args.input_audio,

@@ -188,18 +188,13 @@ class AR(Base):
 quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 )

-if state is not None:
-    logits, state = super().forward(
-        inputs=inputs,
-        state=state,
-    )
-else:
-    logits = super().forward(
-        inputs=inputs,
-        state=state,
-    )
+output = super().forward(
+    inputs=inputs,
+    state=state,
+)
+logits, state = output.logits, output.state

-r = super().sample(
+sampled = super().sample(
 logits=logits,
 prev_list=resps_list,

@@ -219,15 +214,13 @@ class AR(Base):
 dry_allowed_length=sampling_dry_allowed_length,
 )

+r = sampled[0]
+
 if mirostat is not None:
-    # r is the state
-    mirostat = r
-    # extract token from state
-    r = [ state["token"] for state in mirostat ]
-    # we do it here because the sampler will already expand our logits list
+    mirostat = sampled.scores
 elif sampling_beam_width > 0:
     # expand tuple
-    r, s = r
+    scores = sampled.scores
     # first step, expand batch
     if batch_size == 1:
         batch_size = sampling_beam_width

@@ -236,7 +229,7 @@ class AR(Base):
 sequence_list = sequence_list * sampling_beam_width
 stopped = torch.zeros(batch_size, device=device).bool()

-scores = [ scores[i] + score for i, score in enumerate(s) ]
+scores = [ scores[i] + score for i, score in enumerate(scores) ]

 # append tokens
 for i, ri in enumerate(r):

@@ -64,6 +64,8 @@ class AR_NAR(Base):
 sampling_dry_base=1.75,
 sampling_dry_allowed_length=2,

+sampling_entropix=None,
+
 disable_tqdm=False,
 use_lora=None,
 ):

@@ -222,11 +224,9 @@ class AR_NAR(Base):
 inputs=inputs,
 quant_levels=quant_levels,
 )
-if not isinstance( output, tuple ):
-    output = (output, None)
-
-logits, state = output
+logits, state = output.logits, output.state

-resps_list = super().sample(
+sampled = super().sample(
 logits=logits,
 prev_list=prev_list,
 quant_levels=quant_levels,

@@ -242,6 +242,8 @@ class AR_NAR(Base):
 #mirostat=mirostat,
 )

+resps_list = sampled[0]
+
 prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]

 return prev_list

@@ -264,6 +266,10 @@ class AR_NAR(Base):
 ] * batch_size if sampling_mirostat_tau > 0.0 else None

 scores = [ 1.0 ] * sampling_beam_width
+entropies = []

+if sampling_entropix is None:
+    sampling_entropix = self.config.experimental.entropix_sampling
+
 for i, sequence in enumerate( sequence_list ):
 # add <bos> to text for STT

@@ -296,13 +302,11 @@ class AR_NAR(Base):
 output = super().forward(
 inputs=inputs,
 state=state,
+output_attentions=sampling_entropix,
 )
-if not isinstance( output, tuple ):
-    output = (output, None)
-
-logits, state = output
+logits, state = output.logits, output.state

-r = super().sample(
+sampled = super().sample(
 logits=logits,
 prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],

@@ -320,17 +324,20 @@ class AR_NAR(Base):
 dry_multiplier=sampling_dry_multiplier,
 dry_base=sampling_dry_base,
 dry_allowed_length=sampling_dry_allowed_length,

+attentions=output.attentions if sampling_entropix else None,
 )

+r = sampled[0]
+
+if sampled.entropy:
+    entropies.append( sampled.entropy )
+
 if mirostat is not None:
-    # r is the state
-    mirostat = r
-    # extract token from state
-    r = [ state["token"] for state in mirostat ]
-    # we do it here because the sampler will already expand our logits list
+    mirostat = sampled.scores
 elif sampling_beam_width > 0:
     # expand tuple
-    r, s = r
+    scores = sampled.scores
     # first step, expand batch
     if batch_size == 1:
         batch_size = sampling_beam_width

@@ -339,7 +346,7 @@ class AR_NAR(Base):
 sequence_list = sequence_list * sampling_beam_width
 stopped = torch.zeros(batch_size, device=device).bool()

-scores = [ scores[i] + score for i, score in enumerate(s) ]
+scores = [ scores[i] + score for i, score in enumerate(scores) ]

 # append tokens
 for i, ri in enumerate(r):

@@ -354,6 +361,10 @@ class AR_NAR(Base):
 if stopped.all().item():
     break

+if entropies:
+    from ..plot import plot_entropies
+    plot_entropies( entropies )
+
 # pick the best scoring candidate
 # desu this is always going to be candidate 0
 if sampling_beam_width:

@@ -15,8 +15,9 @@ import torch.nn.functional as F
 import random
 import numpy as np
 import re
-from time import perf_counter

+from time import perf_counter
+from collections import namedtuple
 from typing import Literal, overload, Optional, Tuple
 from functools import partial
 from einops import rearrange

@@ -37,6 +38,9 @@ from ..emb.qnt import encode_as_embedding
 # yuck, kind of needed
 from ..data import get_task_symmap

+Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions'])
+Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
+
 """
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """

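Side note (not part of the commit): the new namedtuples stay positionally indexable, which is why some call sites below grab sampled[0] while others read the named fields. A minimal sketch:

from collections import namedtuple

Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
s = Sampled(out=[1, 2, 3], scores=None, entropy=None)
assert s[0] is s.out  # positional access and field access are interchangeable
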
@@ -805,11 +809,15 @@ class Base(nn.Module):
 inputs,
 mask = None,
 position_ids = None,

 state = None,
+output_attentions = False,
 ):
 x = inputs
 m = mask.squeeze(-1).int()

 aux_loss = None
+attentions = None

 # HF transformer derived model
 if self.arch_type in ["llama", "mistral", "mixtral"]:

@@ -819,22 +827,25 @@ class Base(nn.Module):
 past_key_values=state,
 position_ids=position_ids,
 use_cache=not self.training,
-# return_dict=True,
+output_attentions=output_attentions,
+return_dict=True,
 )
 if self.n_experts > 1 and self.training:
     kwargs["output_router_logits"] = True

-t = self.model(**kwargs)
-x = t[0]
+output = self.model(**kwargs)
+x = output["last_hidden_state"]

 # to-do: figure out why KV caching doesn't work
 #if not self.training:
 if state is not None:
-    state = t[1]
+    state = output["past_key_values"]

+if output_attentions:
+    attentions = output["attentions"]
+
 if self.n_experts > 1 and self.training:
-    router_logits = t[-1]
+    router_logits = output["aux_loss"]
     aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
 elif self.arch_type == "transformer":
 # ensures we specify a quant_level for the transformer implementation's AdaLN

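Side note (not part of the commit): a rough sketch of what the HF call above returns when output_attentions=True and return_dict=True, assuming a transformers LlamaModel. output["attentions"] is a per-layer tuple of (batch, heads, seq, seq) weights, and it is only produced by the eager ("naive") attention path, which is why the experimental flag requires it:

import torch
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4)
model = LlamaModel(config)  # the SDPA path falls back to eager when attention weights are requested

output = model(
    inputs_embeds=torch.randn(1, 8, 64),
    output_attentions=True,
    return_dict=True,
)
print(output["last_hidden_state"].shape)                          # torch.Size([1, 8, 64])
print(len(output["attentions"]), output["attentions"][0].shape)   # 2, torch.Size([1, 4, 8, 8])
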
@@ -895,7 +906,7 @@ class Base(nn.Module):
 if self.classifier is not None:
     x = self.classifier(x) * mask

-return x, state, aux_loss
+return Logits(x, state, aux_loss, attentions)

 # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 def inputs(

@@ -1390,6 +1401,7 @@ class Base(nn.Module):
 quant_levels: int | list[int] | Tensor | None = None,
 state: dict | list | None = None,
+output_attentions = False,
 ):
 x_list = self.inputs_to_embeddings( inputs, quant_levels )
 x, m = list_to_tensor(x_list)

@@ -1420,32 +1432,36 @@ class Base(nn.Module):
 # needs to be done here as we still have our raw inputs
 position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None

-x, state, aux_loss = self._forward(
+output = self._forward(
 inputs=x,
 mask=m,
 state=state,
 position_ids=position_ids,
+output_attentions = output_attentions,
 )

+logits = output.logits
+
 # to-do: piece-wise classification, now that there's a head for text
 # although again, one single monolithic head would be preferable instead......
 if self.classifiers is not None:
     special_tasks = [ "len", "stt" ]
     classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
-    x = self.classifiers(x, levels = classifier_quant_levels) * m
+    logits = self.classifiers(logits, levels = classifier_quant_levels) * m

 # Remove padding
-logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
+logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]

 # compute loss if the target is given
 if training:
     self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )

     # include any additional losses (for example: MoE router)
-    if aux_loss is not None:
-        self.loss["aux_loss"] = aux_loss
+    if output.aux_loss is not None:
+        self.loss["aux_loss"] = output.aux_loss

-return (logits, state) if state is not None else logits
+# rewrap, because we're modifying the logits here
+return Logits(logits, output.state, output.aux_loss, output.attentions)

 def sample(
 self,

@@ -1470,10 +1486,15 @@ class Base(nn.Module):
 dry_multiplier=0.0,
 dry_base=1.75,
 dry_allowed_length=2,
+# other
+attentions=None,
 ):
 if min_temperature < 0:
     min_temperature = temperature

+scores = None
+entropy = None
+
 # (NAR) return the entire generated response
 # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely

@@ -1482,9 +1503,114 @@ class Base(nn.Module):
 elif self.causal:
     logits = [ logit[-self.causal_size:] for logit in logits ]

-# this might actually slow things down a bit slightly-er?
-#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+# calculate entropies
+# I would love to shove it in samplers.py but we modify our sampler settings
+if attentions is not None:
+    entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ]
+
+    # this might actually slow things down a bit slightly-er?
+    logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+
+    # to-do: not make it hardcoded to bsz=1
+    metrics = entropy[0]
+    logit = logits[0]
+
+    ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
+    attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
+    agreement = metrics["agreement"]
+    interaction_strength = metrics["interaction_strength"]
+
+    # adjust sample settings
+    cfg = EntropixSamplerConfig()
+
+    # Low Entropy, Low Varentropy: "flowing with unspoken intent"
+    if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
+        entropy[0]["action"] = 0
+        temperature *= 0
+    # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
+    elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
+        entropy[0]["action"] = 1
+        # sample with slightly higher temperature
+        temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy
+    # Low Entropy, High Varentropy: "exploring forks in the path"
+    elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
+        entropy[0]["action"] = 2
+        temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength
+        top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low
+    # High Entropy, High Varentropy: "resampling in the mist"
+    elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
+        entropy[0]["action"] = 3
+        # Use high temperature and adjusted top_p based on attention metrics
+        temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy
+        top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high
+    # Middle ground: use adaptive sampling
+    else:
+        entropy[0]["action"] = 4
+        log_softmax = torch.nn.functional.log_softmax(logit)
+        logits_uncertainty = ent + vent
+        attn_uncertainty = attn_ent + attn_vent
+
+        temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
+        top_p = torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0).item()
+        top_k = int(torch.clip(
+            torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)),
+            min=cfg.top_k_min,
+            max=cfg.top_k_max
+        ))
+        min_p = torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5)
+
+        def _sample( logits ):
+            # perform repetition penalizing
+            if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
+                # to-do: figure out a faster way to handle tolist()
+                logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+
+            # (AR) perform length penalizing
+            if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
+                logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
+
+            # perform top_k/top_p filtering of our logits
+            if top_k > 0 or top_p < 1.0:
+                logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
+
+            # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
+            # epsilon float comparison because I don't trust Python
+            if abs(temperature - min_temperature) >= 0.001:
+                logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
+            else:
+                logits = [ logit / temperature for logit in logits ]
+
+            # do DRY sampling
+            if dry_multiplier > 0.0:
+                logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
+
+            return [ Categorical(logits=logit).sample() for logit in logits ]
+
+        samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]
+
+        def score_sample(sample):
+            one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
+            log_prob = torch.sum(log_softmax * one_hot)
+
+            confidence_score = (
+                (1 - ent) * cfg.ada_score_logits_ent +
+                (1 - attn_ent) * cfg.ada_score_attn_ent +
+                (1 - vent) * cfg.ada_score_logits_vent +
+                (1 - attn_vent) * cfg.ada_score_attn_vent +
+                agreement * cfg.ada_score_agree +
+                interaction_strength * cfg.ada_score_int
+            )
+            return log_prob + confidence_score
+
+        sample_scores = [ score_sample(sample) for sample in samples ]
+        best_sample_idx = torch.argmax(torch.asarray(sample_scores))
+
+        res = samples[best_sample_idx]
+        scores = sample_scores
+        return Sampled(res, scores, entropy)
+
+    temperature = min(1.5, float(temperature))

 # (NAR) disable stop token
 if quant_levels is not None and "ar" in self.capabilities:
     logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]

@@ -1494,7 +1620,9 @@ class Base(nn.Module):
 # argmax instead
 if temperature <= 0.0:
-    return [ logit.argmax(dim=1) for logit in logits ]
+    res = [ logit.argmax(dim=1) for logit in logits ]
+    scores = None
+    return Sampled(res, scores, entropy)

 # perform repetition penalizing
 if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:

@@ -1524,17 +1652,18 @@ class Base(nn.Module):
 # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
 if mirostat is not None:
     # mirostat sampling
-    return [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]
+    scores = [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]
+    res = [ state["token"] for state in scores ]
 # do beam search (naive implementation)
 # picks the top-k across all batches, and re-batches those resultant tokens
 # returns the logit scores as well to be P-concatted with the previous scores
 # to-do: not naively implement beam searching
-if beam_width > 1:
+elif beam_width > 1:
     candidates = top_k_logits_list( logits, beam_width )
     res = [ torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
     scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
-    return res, scores
+# basic sampling
+else:
+    res = [ Categorical(logits=logit).sample() for logit in logits ]

-# and sample
-return [ Categorical(logits=logit).sample() for logit in logits ]
+return Sampled(res, scores, entropy)

@@ -172,16 +172,17 @@ class NAR(Base):
 quant_levels=quant_levels,
 )

-logits = super().forward(
+output = super().forward(
 inputs=inputs,
 quant_levels=quant_levels,
 )
+logits = output.logits

 """
 resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
 """

-resps_list = super().sample(
+sampled = super().sample(
 logits=logits,
 prev_list=prev_list,
 quant_levels=quant_levels,

@@ -196,6 +197,7 @@ class NAR(Base):
 #beam_width=sampling_beam_width,
 #mirostat=mirostat,
 )
+resps_list = sampled[0]

 if n == 0:
     prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]

@@ -225,9 +227,10 @@ class NAR(Base):
 quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 )

-logits = super().forward(
+output = super().forward(
 inputs=inputs,
 )
+logits = output.logits

 r = [ logit[-1:].argmax(dim=1) for logit in logits ]
 # sanitize

@@ -93,6 +93,29 @@ def plot(paths, args):
 #bbox_to_anchor=(1.04, 0.5),
 )

+def plot_entropies( entropies ):
+    """
+    fig = plt.figure()
+    fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
+    """
+
+    data = {}
+
+    for key in entropies[0][0].keys():
+        data[key] = [ e[0][key].item() if hasattr( e[0][key], "item" ) else e[0][key] for e in entropies ]
+
+    df = pd.DataFrame(data)
+    df.plot()
+
+    plt.gca().legend(
+        #loc="center left",
+        fancybox=True,
+        shadow=True,
+        #bbox_to_anchor=(1.04, 0.5),
+    )
+
+    out_path = cfg.rel_path / "metrics.png"
+    plt.savefig(out_path, bbox_inches="tight")
+
 if __name__ == "__main__":
 parser = argparse.ArgumentParser()

@@ -5,6 +5,8 @@ import numpy as np
 from torch import Tensor, einsum, nn

+from dataclasses import asdict, dataclass, field
+
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is

@@ -201,4 +203,86 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2
 break
 logits[:, token] -= factor * base ** (length - allowed_length)

 return logits
+
+LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E
+
+# Grabbed from https://github.com/xjdr-alt/entropix/blob/main/entropix/sampler.py
+# Right now I only care about quantifying these two, I'll figure out how to best apply this to the model
+def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
+    """Calculate the entropy and varentropy of the probability distribution using logsoftmax."""
+    log_probs = torch.nn.functional.log_softmax(logits, dim=dim)
+    probs = torch.exp(log_probs)
+    entropy = -torch.sum(probs * log_probs, dim=dim) / LN_2 # Convert to base-2
+    varentropy = torch.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, dim=dim)
+
+    if attention_scores is None:
+        return {
+            "logits_entropy": torch.mean(entropy).item(),
+            "logits_varentropy": torch.mean(varentropy).item(),
+        }
+
+    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
+    attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clip(attention_probs, 1e-10, 1.0)), dim=-1)
+    attn_varentropy = torch.var(attn_entropy, dim=1)
+
+    mean_attention = torch.mean(attention_probs, dim=1)
+    agreement = torch.mean(torch.abs(attention_probs - mean_attention[:, None, :]), dim=(1, 2))
+
+    interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
+    return {
+        "logits_entropy": torch.mean(entropy),
+        "logits_varentropy": torch.mean(varentropy),
+        "attn_entropy": torch.mean(attn_entropy),
+        "attn_varentropy": torch.mean(attn_varentropy),
+        "agreement": torch.mean(agreement),
+        "interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)),
+    }
+
+# to-do: play around with these values
+@dataclass()
+class EntropixSamplerConfig:
+    temp: float = 0.999
+    top_p: float = 0.90
+    top_k: int = 32
+    min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth
+
+    low_ent_thresh: float = 0.1
+    low_vent_thresh: float = 0.1
+    med_ent_thresh: float = 3.0
+    high_ent_thresh: float = 5.0
+    high_vent_thresh: float = 5.0
+
+    # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible
+    helv_attn_ent_offset: float = 1.3
+    helv_attn_ent_coef: float = 0.2
+
+    lehv_interaction_strength_offset: float = 1.2
+    lehv_interaction_strength_coef: float = 0.3
+
+    hehv_attn_ent_coef: float = 0.2
+    hehv_attn_vent_offset: float = 2.0
+    hehv_attn_vent_coef: float = 0.5
+
+    # TODO not convinced this should
+    n_adaptive_samples: int = 5
+
+    # Adaptive sampling parameters
+    ada_temp_logits: float = 0.3
+    ada_temp_attn: float = 0.2
+    ada_temp_agree: float = 0.2
+    ada_top_p: float = 0.1
+    ada_top_k_int: float = 0.3
+    ada_top_k_agree: float = 0.2
+    ada_min_p: float = 0.5
+    ada_score_logits_ent: float = 0.1
+    ada_score_attn_ent: float = 0.2
+    ada_score_logits_vent: float = 0.3
+    ada_score_attn_vent: float = 0.4
+    ada_score_agree: float = 0.5
+    ada_score_int: float = 0.6
+
+    # extra stuff
+    top_k_min: int = 32
+    top_k_max: int = 128