off-by-one...

This commit is contained in:
mrq 2024-10-31 13:24:48 -05:00
parent b63293cbbe
commit 76ebef45dc
3 changed files with 44 additions and 26 deletions

View File

@@ -257,8 +257,12 @@ class ModelExperimentalSettings:
len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
# to-to: just incorporate this as a task instead
layerskip: bool = False
layerskip: bool = False # layerskip compatible model (or training for)
#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
layerskip_e_scale: float = 0.2 # early-exit loss scalar value
# I really need to clean this up
@dataclass()
@@ -454,30 +458,32 @@ class Evaluation:
# necessary in order to make it not confusing with requiring not-directyl exposed arguments passed to the model
@cached_property
def ar_kwargs( self ):
kwargs = {} | self.kwargs
return dict(
max_steps=self.kwargs["max_ar_steps"],
sampling_temperature=self.kwargs["ar_temp"],
sampling_min_temperature=self.kwargs["min_ar_temp"],
sampling_top_p=self.kwargs["top_p"], sampling_top_k=self.kwargs["top_k"], sampling_min_p=self.kwargs["min_p"],
sampling_repetition_penalty=self.kwargs["repetition_penalty"], sampling_repetition_penalty_decay=self.kwargs["repetition_penalty_decay"],
sampling_length_penalty=self.kwargs["length_penalty"],
sampling_beam_width=self.kwargs["beam_width"],
sampling_mirostat_tau=self.kwargs["mirostat_tau"],
sampling_mirostat_eta=self.kwargs["mirostat_eta"],
sampling_dry_multiplier=self.kwargs["dry_multiplier"],
sampling_dry_base=self.kwargs["dry_base"],
sampling_dry_allowed_length=self.kwargs["dry_allowed_length"],
sampling_entropix=self.kwargs["entropix_sampling"],
max_steps=kwargs.pop("max_ar_steps", 500),
sampling_temperature=kwargs.pop("ar_temp", 0.5),
sampling_min_temperature=kwargs.pop("min_ar_temp", -1),
sampling_top_p=kwargs.pop("top_p", 1.0), sampling_top_k=kwargs.pop("top_k", 0), sampling_min_p=kwargs.pop("min_p", 0.0),
sampling_repetition_penalty=kwargs.pop("repetition_penalty", 1.125), sampling_repetition_penalty_decay=kwargs.pop("repetition_penalty_decay", 0),
sampling_length_penalty=kwargs.pop("length_penalty", 0),
sampling_beam_width=kwargs.pop("beam_width", 0),
sampling_mirostat_tau=kwargs.pop("mirostat_tau", 0),
sampling_mirostat_eta=kwargs.pop("mirostat_eta", 0),
sampling_dry_multiplier=kwargs.pop("dry_multiplier", 0),
sampling_dry_base=kwargs.pop("dry_base", 0),
sampling_dry_allowed_length=kwargs.pop("dry_allowed_length", 0),
sampling_entropix=kwargs.pop("entropix_sampling", False),
)
@cached_property
def nar_kwargs( self ):
kwargs = {} | self.kwargs
return dict(
max_levels=self.kwargs["max_nar_levels"],
sampling_temperature=self.kwargs["nar_temp"],
sampling_min_temperature=self.kwargs["min_nar_temp"],
sampling_top_p=self.kwargs["top_p"], sampling_top_k=self.kwargs["top_k"], sampling_min_p=self.kwargs["min_p"],
sampling_repetition_penalty=self.kwargs["repetition_penalty"], sampling_repetition_penalty_decay=self.kwargs["repetition_penalty_decay"],
max_levels=kwargs.pop("max_nar_levels", 0),
sampling_temperature=kwargs.pop("nar_temp", 0.0),
sampling_min_temperature=kwargs.pop("min_nar_temp", -1),
sampling_top_p=kwargs.pop("top_p", 1.0), sampling_top_k=kwargs.pop("top_k", 0.0), sampling_min_p=kwargs.pop("min_p", 0.0),
sampling_repetition_penalty=kwargs.pop("repetition_penalty", 1.0), sampling_repetition_penalty_decay=kwargs.pop("repetition_penalty_decay", 0.0),
)
@dataclass()

View File

@@ -308,6 +308,7 @@ class LlamaModel_Adapted(LlamaModel):
def __init__(self, *args, **kwargs):
self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
self.early_exit_r = kwargs.pop("early_exit_r", 2)
super().__init__(*args, **kwargs)
@@ -321,16 +322,16 @@ class LlamaModel_Adapted(LlamaModel):
P = D * self.layer_dropout_p
return random.random() < P
# to-do: properly implement this per the paper
# there doesn't seem /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit
def cirriculum( self, l, t=None, R=2 ):
def cirriculum( self, l, t=None ):
# no timestep data passed, just treat all layers as enabled
# there doesn't seem /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit
if t is None:
return 1
# YUCK
# this guarantees at least R layers are active at all intervals, which is important because this gives a division by zero otherwise
for i in range(R):
if l == ((t % self.layers_n) + i * (self.layers_n // R)) % self.layers_n:
for i in range(self.early_exit_r):
if l == ((t % self.layers_n) + i * (self.layers_n // self.early_exit_r)) % self.layers_n:
return 1
return 0
@@ -357,6 +358,7 @@ class LlamaModel_Adapted(LlamaModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
early_exit_layer: Optional[int] = -1,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (

View File

@@ -443,7 +443,11 @@ class Base(nn.Module):
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False
layerskip = self.config.experimental.layerskip if self.config is not None else False
layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2
layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
@@ -648,6 +652,11 @@ class Base(nn.Module):
if attention_backend not in HF_ATTENTIONS:
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
if self.layerskip:
self.model.layer_dropout_p = layerskip_p_max
self.model.early_exit_scale = layerskip_e_scale
self.model.early_exit_r = layerskip_r
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
@@ -917,7 +926,8 @@ class Base(nn.Module):
# process it into a format that I like
if output_hidden_states:
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i in range( self.n_layers ) ]
# hidden_states is actually layers + 1, as hidden_states[0] == embedding...........
hidden_states = [ x if i == self.n_layers else self.model.norm(output.hidden_states[i]) for i in range( 1, self.n_layers + 1 ) ]
# output projection layer with masking
if self.classifier is not None: