off-by-one...
parent b63293cbbe
commit 76ebef45dc
@@ -257,8 +257,12 @@ class ModelExperimentalSettings:
 	len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
 	# to-do: just incorporate this as a task instead
 
-	layerskip: bool = False
+	layerskip: bool = False # layerskip compatible model (or training for)
+	#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
+	layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
+	layerskip_p_max: float = 0.1 # maximum probability to dropout the last layer, used for calculating layer dropout probabilities
+	layerskip_e_scale: float = 0.2 # early-exit loss scalar value
 
 # I really need to clean this up
 @dataclass()
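These fields are the knobs for the LayerSkip-style training wired up in the hunks below. A minimal sketch of how they might be set (the dataclass is the one patched above; the values are illustrative):

    settings = ModelExperimentalSettings(
        layerskip=True,        # train with layer dropout + early-exit losses
        layerskip_r=2,         # how many layers get early-exit supervision per step
        layerskip_p_max=0.1,   # dropout probability for the deepest layer; shallower layers scale down
        layerskip_e_scale=0.2, # weight on the auxiliary early-exit loss terms
    )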
@@ -454,30 +458,32 @@ class Evaluation:
 	# necessary in order to make it not confusing with requiring not-directly exposed arguments passed to the model
 	@cached_property
 	def ar_kwargs( self ):
+		kwargs = {} | self.kwargs
 		return dict(
-			max_steps=self.kwargs["max_ar_steps"],
-			sampling_temperature=self.kwargs["ar_temp"],
-			sampling_min_temperature=self.kwargs["min_ar_temp"],
-			sampling_top_p=self.kwargs["top_p"], sampling_top_k=self.kwargs["top_k"], sampling_min_p=self.kwargs["min_p"],
-			sampling_repetition_penalty=self.kwargs["repetition_penalty"], sampling_repetition_penalty_decay=self.kwargs["repetition_penalty_decay"],
-			sampling_length_penalty=self.kwargs["length_penalty"],
-			sampling_beam_width=self.kwargs["beam_width"],
-			sampling_mirostat_tau=self.kwargs["mirostat_tau"],
-			sampling_mirostat_eta=self.kwargs["mirostat_eta"],
-			sampling_dry_multiplier=self.kwargs["dry_multiplier"],
-			sampling_dry_base=self.kwargs["dry_base"],
-			sampling_dry_allowed_length=self.kwargs["dry_allowed_length"],
-			sampling_entropix=self.kwargs["entropix_sampling"],
+			max_steps=kwargs.pop("max_ar_steps", 500),
+			sampling_temperature=kwargs.pop("ar_temp", 0.5),
+			sampling_min_temperature=kwargs.pop("min_ar_temp", -1),
+			sampling_top_p=kwargs.pop("top_p", 1.0), sampling_top_k=kwargs.pop("top_k", 0), sampling_min_p=kwargs.pop("min_p", 0.0),
+			sampling_repetition_penalty=kwargs.pop("repetition_penalty", 1.125), sampling_repetition_penalty_decay=kwargs.pop("repetition_penalty_decay", 0),
+			sampling_length_penalty=kwargs.pop("length_penalty", 0),
+			sampling_beam_width=kwargs.pop("beam_width", 0),
+			sampling_mirostat_tau=kwargs.pop("mirostat_tau", 0),
+			sampling_mirostat_eta=kwargs.pop("mirostat_eta", 0),
+			sampling_dry_multiplier=kwargs.pop("dry_multiplier", 0),
+			sampling_dry_base=kwargs.pop("dry_base", 0),
+			sampling_dry_allowed_length=kwargs.pop("dry_allowed_length", 0),
+			sampling_entropix=kwargs.pop("entropix_sampling", False),
 		)
 
 	@cached_property
 	def nar_kwargs( self ):
+		kwargs = {} | self.kwargs
 		return dict(
-			max_levels=self.kwargs["max_nar_levels"],
-			sampling_temperature=self.kwargs["nar_temp"],
-			sampling_min_temperature=self.kwargs["min_nar_temp"],
-			sampling_top_p=self.kwargs["top_p"], sampling_top_k=self.kwargs["top_k"], sampling_min_p=self.kwargs["min_p"],
-			sampling_repetition_penalty=self.kwargs["repetition_penalty"], sampling_repetition_penalty_decay=self.kwargs["repetition_penalty_decay"],
+			max_levels=kwargs.pop("max_nar_levels", 0),
+			sampling_temperature=kwargs.pop("nar_temp", 0.0),
+			sampling_min_temperature=kwargs.pop("min_nar_temp", -1),
+			sampling_top_p=kwargs.pop("top_p", 1.0), sampling_top_k=kwargs.pop("top_k", 0.0), sampling_min_p=kwargs.pop("min_p", 0.0),
+			sampling_repetition_penalty=kwargs.pop("repetition_penalty", 1.0), sampling_repetition_penalty_decay=kwargs.pop("repetition_penalty_decay", 0.0),
 		)
 
 @dataclass()
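The change in both properties is the same: instead of indexing self.kwargs directly, which raises a KeyError for any sampling option the config omits, a shallow copy is popped with per-option defaults. A small illustration of the pattern, outside this codebase:

    source = {"ar_temp": 0.95}            # hypothetical partial config
    kwargs = {} | source                  # PEP 584 dict union: shallow copy, source left untouched
    temp  = kwargs.pop("ar_temp", 0.5)    # key present -> 0.95
    top_p = kwargs.pop("top_p", 1.0)      # key missing -> 1.0 instead of a KeyError
    assert source == {"ar_temp": 0.95}    # the original dict is not mutated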

@@ -308,6 +308,7 @@ class LlamaModel_Adapted(LlamaModel):
 	def __init__(self, *args, **kwargs):
 		self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
 		self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
+		self.early_exit_r = kwargs.pop("early_exit_r", 2)
 
 		super().__init__(*args, **kwargs)
@@ -321,16 +322,16 @@ class LlamaModel_Adapted(LlamaModel):
 		P = D * self.layer_dropout_p
 		return random.random() < P
 
 	# to-do: properly implement this per the paper
-	# there doesn't seem to be /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit
-	def cirriculum( self, l, t=None, R=2 ):
+	def cirriculum( self, l, t=None ):
 		# no timestep data passed, just treat all layers as enabled
+		# there doesn't seem to be /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit
 		if t is None:
 			return 1
 
 		# YUCK
 		# this guarantees at least R layers are active at all intervals, which is important because this gives a division by zero otherwise
-		for i in range(R):
-			if l == ((t % self.layers_n) + i * (self.layers_n // R)) % self.layers_n:
+		for i in range(self.early_exit_r):
+			if l == ((t % self.layers_n) + i * (self.layers_n // self.early_exit_r)) % self.layers_n:
 				return 1
 		return 0
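For a concrete view of what this rotation does, a standalone re-implementation (the same arithmetic as the loop above, with illustrative layers_n and early_exit_r values): it activates early_exit_r evenly spaced exit layers and rotates them with the timestep, so every layer still gets early-exit supervision over a window of steps:

    def active_exits( t, layers_n=12, early_exit_r=2 ):
        # mirrors cirriculum(): layer l is active when it lands on one of the
        # early_exit_r rotation slots, spaced layers_n // early_exit_r apart
        return [
            l for l in range(layers_n)
            if any(
                l == ((t % layers_n) + i * (layers_n // early_exit_r)) % layers_n
                for i in range(early_exit_r)
            )
        ]

    for t in range(3):
        print(t, active_exits(t))
    # 0 [0, 6]
    # 1 [1, 7]
    # 2 [2, 8]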
@@ -357,6 +358,7 @@ class LlamaModel_Adapted(LlamaModel):
 		output_hidden_states: Optional[bool] = None,
 		return_dict: Optional[bool] = None,
 		cache_position: Optional[torch.LongTensor] = None,
+		early_exit_layer: Optional[int] = -1,
 	) -> Union[Tuple, BaseModelOutputWithPast]:
 		output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 		output_hidden_states = (
@@ -443,7 +443,11 @@ class Base(nn.Module):
 		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 		interleave = self.config.experimental.interleave if self.config is not None else False
 
+		layerskip = self.config.experimental.layerskip if self.config is not None else False
+		layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2
+		layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
+		layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1
 
 		n_tasks = self.config.tasks if self.config is not None else 8
 		n_langs = self.config.langs if self.config is not None else 2
@@ -648,6 +652,11 @@ class Base(nn.Module):
 			if attention_backend not in HF_ATTENTIONS:
 				self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
 
+		if self.layerskip:
+			self.model.layer_dropout_p = layerskip_p_max
+			self.model.early_exit_scale = layerskip_e_scale
+			self.model.early_exit_r = layerskip_r
+
 		if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 			self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 				use_reentrant=False
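This hunk copies layerskip_p_max into self.model.layer_dropout_p, and the dropoff logic earlier computes the per-layer chance as P = D * self.layer_dropout_p, so p_max only applies in full to the deepest layer. The depth scale D is not shown in this diff; assuming the exponential curve from the LayerSkip paper (D = 2**(l / (L - 1)) - 1, i.e. 0 for the first layer and 1 for the last), the mapping would look like:

    import math, random

    def dropoff_layer( l, layers_n, layer_dropout_p=0.1 ):
        # assumed depth scale: grows from 0 (first layer) to 1 (last layer),
        # so the last layer is dropped with probability layer_dropout_p (= layerskip_p_max)
        D = math.exp((l * math.log(2)) / (layers_n - 1)) - 1
        P = D * layer_dropout_p
        return random.random() < P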
@@ -917,7 +926,8 @@ class Base(nn.Module):
 
 		# process it into a format that I like
 		if output_hidden_states:
-			hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i in range( self.n_layers ) ]
+			# hidden_states is actually layers + 1, as hidden_states[0] == embedding
+			hidden_states = [ x if i == self.n_layers else self.model.norm(output.hidden_states[i]) for i in range( 1, self.n_layers + 1 ) ]
 
 		# output projection layer with masking
 		if self.classifier is not None:
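This is the off-by-one the commit title refers to, and it matches how HuggingFace transformers reports hidden states: with output_hidden_states=True the tuple holds num_hidden_layers + 1 entries, with hidden_states[0] being the embedding output, so per-layer activations live at indices 1..n_layers. A toy check (illustrative config values, not from this repo):

    import torch
    from transformers import LlamaConfig, LlamaModel

    config = LlamaConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                         num_hidden_layers=4, num_attention_heads=4)
    model = LlamaModel(config)
    out = model(input_ids=torch.zeros((1, 8), dtype=torch.long), output_hidden_states=True)

    assert len(out.hidden_states) == config.num_hidden_layers + 1    # layers + embedding
    assert torch.equal(out.hidden_states[-1], out.last_hidden_state) # final entry is already normed,
                                                                     # hence the "x if i == self.n_layers" branch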