diff --git a/vall_e/config.py b/vall_e/config.py index cd60cdd..db10966 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -257,8 +257,12 @@ class ModelExperimentalSettings: len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len # to-to: just incorporate this as a task instead - - layerskip: bool = False + + layerskip: bool = False # layerskip compatible model (or training for) + #layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters) + layerskip_r: int = 2 # number of layers to factor into early-exit loss calc + layerskip_p_max: float = 0.1 # maximum probability to dropout the last layer, used for calculating layer dropout probabilities + layerskip_e_scale: float = 0.2 # early-exit loss scalar value # I really need to clean this up @dataclass() @@ -454,30 +458,32 @@ class Evaluation: # necessary in order to make it not confusing with requiring not-directyl exposed arguments passed to the model @cached_property def ar_kwargs( self ): + kwargs = {} | self.kwargs return dict( - max_steps=self.kwargs["max_ar_steps"], - sampling_temperature=self.kwargs["ar_temp"], - sampling_min_temperature=self.kwargs["min_ar_temp"], - sampling_top_p=self.kwargs["top_p"], sampling_top_k=self.kwargs["top_k"], sampling_min_p=self.kwargs["min_p"], - sampling_repetition_penalty=self.kwargs["repetition_penalty"], sampling_repetition_penalty_decay=self.kwargs["repetition_penalty_decay"], - sampling_length_penalty=self.kwargs["length_penalty"], - sampling_beam_width=self.kwargs["beam_width"], - sampling_mirostat_tau=self.kwargs["mirostat_tau"], - sampling_mirostat_eta=self.kwargs["mirostat_eta"], - sampling_dry_multiplier=self.kwargs["dry_multiplier"], - sampling_dry_base=self.kwargs["dry_base"], - sampling_dry_allowed_length=self.kwargs["dry_allowed_length"], - sampling_entropix=self.kwargs["entropix_sampling"], + max_steps=kwargs.pop("max_ar_steps", 500), + 
sampling_temperature=kwargs.pop("ar_temp", 0.5), + sampling_min_temperature=kwargs.pop("min_ar_temp", -1), + sampling_top_p=kwargs.pop("top_p", 1.0), sampling_top_k=kwargs.pop("top_k", 0), sampling_min_p=kwargs.pop("min_p", 0.0), + sampling_repetition_penalty=kwargs.pop("repetition_penalty", 1.125), sampling_repetition_penalty_decay=kwargs.pop("repetition_penalty_decay", 0), + sampling_length_penalty=kwargs.pop("length_penalty", 0), + sampling_beam_width=kwargs.pop("beam_width", 0), + sampling_mirostat_tau=kwargs.pop("mirostat_tau", 0), + sampling_mirostat_eta=kwargs.pop("mirostat_eta", 0), + sampling_dry_multiplier=kwargs.pop("dry_multiplier", 0), + sampling_dry_base=kwargs.pop("dry_base", 0), + sampling_dry_allowed_length=kwargs.pop("dry_allowed_length", 0), + sampling_entropix=kwargs.pop("entropix_sampling", False), ) @cached_property def nar_kwargs( self ): + kwargs = {} | self.kwargs return dict( - max_levels=self.kwargs["max_nar_levels"], - sampling_temperature=self.kwargs["nar_temp"], - sampling_min_temperature=self.kwargs["min_nar_temp"], - sampling_top_p=self.kwargs["top_p"], sampling_top_k=self.kwargs["top_k"], sampling_min_p=self.kwargs["min_p"], - sampling_repetition_penalty=self.kwargs["repetition_penalty"], sampling_repetition_penalty_decay=self.kwargs["repetition_penalty_decay"], + max_levels=kwargs.pop("max_nar_levels", 0), + sampling_temperature=kwargs.pop("nar_temp", 0.0), + sampling_min_temperature=kwargs.pop("min_nar_temp", -1), + sampling_top_p=kwargs.pop("top_p", 1.0), sampling_top_k=kwargs.pop("top_k", 0.0), sampling_min_p=kwargs.pop("min_p", 0.0), + sampling_repetition_penalty=kwargs.pop("repetition_penalty", 1.0), sampling_repetition_penalty_decay=kwargs.pop("repetition_penalty_decay", 0.0), ) @dataclass() diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index 3da812d..343f6dc 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -308,6 +308,7 @@ class LlamaModel_Adapted(LlamaModel): def 
__init__(self, *args, **kwargs): self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1) self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1) + self.early_exit_r = kwargs.pop("early_exit_r", 2) super().__init__(*args, **kwargs) @@ -321,16 +322,16 @@ class LlamaModel_Adapted(LlamaModel): P = D * self.layer_dropout_p return random.random() < P - # to-do: properly implement this per the paper - # there doesn't seem /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit - def cirriculum( self, l, t=None, R=2 ): + def cirriculum( self, l, t=None ): + # no timestep data passed, just treat all layers as enabled + # there doesn't seem /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit if t is None: return 1 # YUCK # this guarantees at least R layers are active at all intervals, which is important because this gives a division by zero otherwise - for i in range(R): - if l == ((t % self.layers_n) + i * (self.layers_n // R)) % self.layers_n: + for i in range(self.early_exit_r): + if l == ((t % self.layers_n) + i * (self.layers_n // self.early_exit_r)) % self.layers_n: return 1 return 0 @@ -357,6 +358,7 @@ class LlamaModel_Adapted(LlamaModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + early_exit_layer: Optional[int] = -1, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 09cea8c..b712b24 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -443,7 +443,11 @@ class Base(nn.Module): audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else "" unified_position_ids = 
self.config.experimental.unified_position_ids if self.config is not None else True interleave = self.config.experimental.interleave if self.config is not None else False + layerskip = self.config.experimental.layerskip if self.config is not None else False + layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2 + layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1 + layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1 n_tasks = self.config.tasks if self.config is not None else 8 n_langs = self.config.langs if self.config is not None else 2 @@ -648,6 +652,11 @@ class Base(nn.Module): if attention_backend not in HF_ATTENTIONS: self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend ) + if self.layerskip: + self.model.layer_dropout_p = layerskip_p_max + self.model.early_exit_scale = layerskip_e_scale + self.model.early_exit_r = layerskip_r + if self.gradient_checkpointing and not self.model.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False @@ -917,7 +926,8 @@ class Base(nn.Module): # process it into a format that I like if output_hidden_states: - hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i in range( self.n_layers ) ] + # hidden_states is actually layers + 1, as hidden_states[0] == embedding........... + hidden_states = [ x if i == self.n_layers else self.model.norm(output.hidden_states[i]) for i in range( 1, self.n_layers + 1 ) ] # output projection layer with masking if self.classifier is not None: