From a22534e8f4ecd37d6bd47e33d2bc3a38e297fcca Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 30 Oct 2024 20:05:45 -0500 Subject: [PATCH] layer skip training implemented (need to gut the inferencing from the repo, and to actually see if the model can benefit from this) --- vall_e/config.py | 7 +- vall_e/demo.py | 5 +- vall_e/models/ar_nar.py | 2 +- vall_e/models/arch/__init__.py | 2 +- vall_e/models/arch/llama.py | 171 +++++++++++++++++++++++++++++++-- vall_e/models/base.py | 99 +++++++++++++++---- vall_e/models/nar.py | 4 +- 7 files changed, 251 insertions(+), 39 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index d176a11..cd60cdd 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -255,8 +255,10 @@ class ModelExperimentalSettings: # it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token # RetNet's chunked inferencing might be a better place for this - p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len + len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len # to-to: just incorporate this as a task instead + + layerskip: bool = False # I really need to clean this up @dataclass() @@ -870,9 +872,6 @@ class Config(BaseConfig): def format( self, training=True ): - print( self.models ) - print( self.loras ) - if isinstance(self.dataset, type): self.dataset = dict() diff --git a/vall_e/demo.py b/vall_e/demo.py index 5064065..f8e4d7f 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -43,7 +43,6 @@ def main(): parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--model", type=Path, default=None) - parser.add_argument("--lora", type=Path, default=None) parser.add_argument("--demo-dir", type=Path, default=None) parser.add_argument("--skip-existing", action="store_true") @@ -245,8 +244,8 @@ def main(): metadata = batch["metadata"] - #text = get_random_prompt() if args.random_prompts else metadata["text"] - text = get_random_prompt() if i >= (num // 2) else metadata["text"] + text = get_random_prompt() if args.random_prompts else metadata["text"] + #text = get_random_prompt() if i >= (num // 2) else metadata["text"] language = metadata["language"].lower() prompt = dir / "prompt.wav" diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 717cb43..a2f60b1 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -496,7 +496,7 @@ def example_usage(): bos_id, space_id, eos_id = cfg.tokenizer.encode( " " ) #available_tasks = cfg.dataset.tasks_list - available_tasks = ["tts", "stt"] + available_tasks = ["tts"] # , "stt"] model = AR_NAR(**kwargs).to(device) steps = 500 # 150 * len(available_tasks) # * cfg.model.experimental.causal_size diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py index 187da69..4eb89d7 100755 --- a/vall_e/models/arch/__init__.py +++ b/vall_e/models/arch/__init__.py @@ -30,7 +30,7 @@ except Exception as e: pass try: - from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM + from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM AVAILABLE_ARCHES.append("llama") except Exception as e: ERROR_ARCHES["llama"] = e diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index 3b9bed7..81fb5bd 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -3,19 +3,24 @@ import math import 
torch import logging +import random -from typing import Literal, overload, Optional, Tuple +from typing import Literal, overload, Optional, Tuple, Union, List, Unpack from torch import Tensor, nn from transformers.cache_utils import Cache from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv _logger = logging.getLogger(__name__) AVAILABLE_ATTENTIONS = [] +LN_2 = 0.69314718056 + try: from transformers.utils import is_flash_attn_2_available @@ -131,11 +136,7 @@ if AVAILABLE_ATTENTIONS: class LlamaAttention_Adapted(LlamaAttention): def __init__(self, *args, **kwargs): - if 'mode' in kwargs: - self.mode = kwargs['mode'] - kwargs.pop("mode") - else: - self.mode = "sdpa" + self.mode = kwargs.pop("mode", "sdpa") if self.mode == "math": self.mode = torch.nn.attention.SDPBackend.MATH @@ -301,4 +302,160 @@ class LlamaAttention_Adapted(LlamaAttention): attn_output = self.o_proj(attn_output) - return attn_output, attn_scores, past_key_value \ No newline at end of file + return attn_output, attn_scores, past_key_value + +class LlamaModel_Adapted(LlamaModel): + def __init__(self, *args, **kwargs): + self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1) + self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1) + + super().__init__(*args, **kwargs) + + self.layers_n = len(self.layers) + def dropoff_layer( self, l ): + if not self.training: + return False + + # this could probably a LUT but I'm not fiending for aggressive mal-optimizations + D = math.exp((l * LN_2) / (self.layers_n - 1)) - 1 + P = D * self.layer_dropout_p + return random.random() < P + + # to-do: properly implement this per the paper + # this probably is a function of layer number and training step to decide what layer to apply layerskip to for training + def cirriculum( self, l, t=0 ): + return 1 # self.layers_n - 1 + + def early_exit_loss( self, losses, t=0 ): + return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ]) + + def normalized_per_layer_loss_scale( self, l, t=0 ): + return (self.cirriculum(l, t) * self.early_exit_factor( l )) / (sum([ self.cirriculum(l, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ])) + + def early_exit_factor( self, l ): + if 0 <= l and l < self.layers_n: + return self.early_exit_scale * sum([ i for i in range(0, l) ]) + return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ]) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = 
return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for l, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + if not self.dropoff_layer( l ): + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) \ No newline at end of file diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 5297d01..d2ea128 100755 --- a/vall_e/models/base.py +++ 
b/vall_e/models/base.py @@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding # yuck, kind of needed from ..data import get_task_symmap -Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions']) +Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states']) Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict +LossStats = namedtuple('LossStats', ['loss', 'stats']) """ from ..utils.pattern import DelayedPatternProvider, VALLEPattern @@ -442,6 +443,7 @@ class Base(nn.Module): audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else "" unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True interleave = self.config.experimental.interleave if self.config is not None else False + layerskip = self.config.experimental.layerskip if self.config is not None else False n_tasks = self.config.tasks if self.config is not None else 8 n_langs = self.config.langs if self.config is not None else 2 @@ -469,6 +471,7 @@ class Base(nn.Module): self.unified_position_ids = unified_position_ids self.interleave = interleave + self.layerskip = layerskip self.text_emb = Embedding(n_text_tokens, d_model) self.langs_emb = None @@ -601,8 +604,9 @@ class Base(nn.Module): use_reentrant=False )) elif self.arch_type == "llama": + LlamaClass = LlamaModel_Adapted if self.layerskip else LlamaModel if n_experts <= 1: - self.model = LlamaModel(LlamaConfig( + self.model = LlamaClass(LlamaConfig( vocab_size=n_resp_tokens, hidden_size=d_model, max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds @@ -814,12 +818,14 @@ class Base(nn.Module): state = None, output_attentions = False, + output_hidden_states = False, ): x = inputs m = mask.squeeze(-1).int() aux_loss = None attentions = None + hidden_states = None # HF transformer derived model if self.arch_type in ["llama", "mistral", "mixtral"]: @@ -830,6 +836,7 @@ class Base(nn.Module): position_ids=position_ids, use_cache=not self.training, output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=True, ) if self.n_experts > 1 and self.training: @@ -846,6 +853,9 @@ class Base(nn.Module): if output_attentions: attentions = output["attentions"] + if output_hidden_states: + hidden_states = output["hidden_states"] + if self.n_experts > 1 and self.training: router_logits = output["aux_loss"] aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok ) @@ -904,11 +914,19 @@ class Base(nn.Module): x = x[0] + # process it into a format that I like + if output_hidden_states: + hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i in range( self.n_layers ) ] + # output projection layer with masking if self.classifier is not None: x = self.classifier(x) * mask + + if output.hidden_states: + for i in range( self.n_layers ): + hidden_states[i] = self.classifier(hidden_states[i]) * m - return Logits(x, state, aux_loss, attentions) + return Logits(x, state, aux_loss, attentions, hidden_states) # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation def inputs( @@ -1217,6 +1235,9 @@ class Base(nn.Module): quant_levels: int | list[int] | Tensor | None = None, ): + loss = dict(ce = dict()) + stats = dict(acc = dict()) + device = logits[0].device 
special_tasks = [ "len", "stt" ] summed_embeddings_task = [ "stt" ] @@ -1285,23 +1306,22 @@ class Base(nn.Module): if False: target = torch.cat( target_list ) inputs = torch.cat( logits ) - self.loss = dict( - # "nll" was in the original implementation and should actually just be called something else + loss = dict( nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index ) ) - self.stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict( + stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict( acc = self.accuracy_metric( inputs, target ), # precision = self.precision_metric( inputs, target ), ) else: - self.loss = dict( + loss = dict( nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size ) - self.stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict( + stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict( acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size ) - return + return LossStats(loss, stats) """ # considerations: @@ -1311,9 +1331,6 @@ class Base(nn.Module): # + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly # + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes) """ - self.loss = dict() - self.stats = dict(acc = dict()) - info = {} batch_size = len( inputs ) @@ -1385,17 +1402,19 @@ class Base(nn.Module): if False: targets = torch.cat( batch["targets"] ).long() inputs = torch.cat( batch["logits"] ) - self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor - self.stats["acc"][name] = self.accuracy_metric( inputs, targets ) + loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor + stats["acc"][name] = self.accuracy_metric( inputs, targets ) # probably consumes less memory due to not having to allocate memory # this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed) else: - self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size + loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size if self.metrics is not None: metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels ) - self.stats["acc"][name] = metrics["acc"] + stats["acc"][name] = metrics["acc"] else: - self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size + stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size + + return LossStats(loss, stats) def forward( self, @@ -1404,6 +1423,7 @@ class Base(nn.Module): quant_levels: int | list[int] | Tensor | None = None, state: dict | list | None = None, output_attentions = False, + output_hidden_states = False, ): x_list = self.inputs_to_embeddings( inputs, quant_levels ) x, m = 
list_to_tensor(x_list) @@ -1412,10 +1432,12 @@ class Base(nn.Module): device = x.device batch_size = len(x_list) - # pure AR if quant_levels is None: quant_levels = [ 0 for _ in range(batch_size) ] + + if self.layerskip: + output_hidden_states = True # pad our input and mask, but retain the original length by doing it after if self.l_padding and x.shape[1] % self.l_padding != 0: @@ -1440,9 +1462,11 @@ class Base(nn.Module): state=state, position_ids=position_ids, output_attentions = output_attentions, + output_hidden_states = output_hidden_states, ) logits = output.logits + hidden_states = output.hidden_states # to-do: piece-wise classification, now that there's a head for text # although again, one single monolithic head would be preferable instead...... @@ -1451,19 +1475,52 @@ class Base(nn.Module): classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ] logits = self.classifiers(logits, levels = classifier_quant_levels) * m + if hidden_states is not None: + for i in range( self.n_layers ): + hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m + # Remove padding logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ] # compute loss if the target is given if training: - self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) + if not self.layerskip: + loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) + else: + self.loss = {} + self.stats = {} + + for i in range( self.n_layers ): + # remove padding + hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ] + loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels ) + + for k, v in loss.items(): + if k not in self.loss: + self.loss[k] = [] + self.loss[k].append( v ) + + for k, v in stats.items(): + if k not in self.stats: + self.stats[k] = [] + self.stats[k].append( v ) + + for k, v in self.loss.items(): + self.loss[k] = self.model.early_exit_loss( losses=v ) + + for k, v in self.stats.items(): + self.stats[k] = sum( v ) / len( v ) + # include any additional losses (for example: MoE router) if output.aux_loss is not None: - self.loss["aux_loss"] = output.aux_loss + loss["aux_loss"] = output.aux_loss + + self.loss = loss + self.stats = stats # rewrap, because we're modifying the logits here - return Logits(logits, output.state, output.aux_loss, output.attentions) + return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states) def sample( self, diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index b0be902..89b0ebe 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -60,7 +60,7 @@ class NAR(Base): # is training if resps_list is not None: - p_len_task = self.config.experimental.p_len_train if self.config is not None else 0.05 + len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05 n_levels_set = {r.shape[-1] for r in resps_list} n_levels = next(iter(n_levels_set)) @@ -69,7 +69,7 @@ class NAR(Base): # to-do: make this YAML configurable def sample_task(): - return "len" if random.random() < p_len_task else "tts" + return "len" if random.random() < len_train_p else "tts" # generate task list to train against task_list = [ sample_task() for _ in range(batch_size) ]
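
Notes on the scheduling above (a minimal standalone sketch, not part of the diff): LlamaModel_Adapted ramps the per-layer dropout probability exponentially with depth and weights each layer's early-exit loss by a normalized, depth-increasing factor. The helper names below (num_layers, p_max, early_exit_scale, per_layer_losses) are illustrative only, and the curriculum term (which the patch currently stubs out to 1) is left out.

import math
import random

LN_2 = math.log(2)  # same constant the patch hardcodes as 0.69314718056

def layer_dropout_p(l: int, num_layers: int, p_max: float = 0.1) -> float:
    # exponential ramp over depth: p(0) = 0, p(num_layers - 1) = p_max
    return p_max * (math.exp((l * LN_2) / (num_layers - 1)) - 1)

def drop_layer(l: int, num_layers: int, p_max: float = 0.1) -> bool:
    # training-time coin flip: when True, the layer's residual update is skipped
    return random.random() < layer_dropout_p(l, num_layers, p_max)

def early_exit_factor(l: int, num_layers: int, scale: float = 0.1) -> float:
    # unnormalized weight for the loss taken from layer l's hidden state;
    # sum(range(l)) == l * (l - 1) / 2, so deeper layers weigh more
    if 0 <= l < num_layers:
        return scale * sum(range(l))
    return (num_layers - 1) + scale * sum(range(num_layers - 1))

def early_exit_loss(per_layer_losses: list[float], scale: float = 0.1) -> float:
    # weighted sum of per-layer losses, with the weights normalized to sum to 1
    num_layers = len(per_layer_losses)
    factors = [early_exit_factor(l, num_layers, scale) for l in range(num_layers)]
    total = sum(factors) or 1.0  # guard for stacks of <= 2 layers, where every factor is 0
    return sum(f / total * loss for f, loss in zip(factors, per_layer_losses))

if __name__ == "__main__":
    L = 12
    print([round(layer_dropout_p(l, L), 3) for l in range(L)])   # 0.0 ... 0.1
    print(round(early_exit_loss([1.0] * L), 3))                  # 1.0

With p_max = 0.1 the first layer is never dropped and the last layer is dropped 10% of the time during training; the early-exit weights are normalized, so a uniform per-layer loss of 1.0 maps back to an overall loss of 1.0.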