layer skip training implemented (need to gut the inferencing from the repo, and to actually see if the model can benefit from this)

This commit is contained in:
mrq 2024-10-30 20:05:45 -05:00
parent 4049f51ba9
commit a22534e8f4
7 changed files with 251 additions and 39 deletions

View File

@ -255,8 +255,10 @@ class ModelExperimentalSettings:
# it just seems like a bitch to try and train something worthwhile with it, since there are crackles every other token
# RetNet's chunked inferencing might be a better place for this
p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
# to-do: just incorporate this as a task instead
layerskip: bool = False
# I really need to clean this up
@dataclass()
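For orientation, the two experimental knobs touched here can be pictured in isolation; a minimal standalone sketch (the field names come from this diff, everything else is illustrative):

# Illustrative only: the renamed len-task knob and the new layerskip flag, trimmed
# out of ModelExperimentalSettings so they can be poked at on their own.
from dataclasses import dataclass

@dataclass
class _ExperimentalSettingsSketch:
    len_train_p: float = 0.05   # odds of injecting a "len" task for NAR-len training
    layerskip: bool = False     # swap LlamaModel for LlamaModel_Adapted and train early exits

settings = _ExperimentalSettingsSketch(layerskip=True)
print(settings)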
@ -870,9 +872,6 @@ class Config(BaseConfig):
def format( self, training=True ):
print( self.models )
print( self.loras )
if isinstance(self.dataset, type):
self.dataset = dict()

View File

@ -43,7 +43,6 @@ def main():
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--model", type=Path, default=None)
parser.add_argument("--lora", type=Path, default=None)
parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true")
@ -245,8 +244,8 @@ def main():
metadata = batch["metadata"]
#text = get_random_prompt() if args.random_prompts else metadata["text"]
text = get_random_prompt() if i >= (num // 2) else metadata["text"]
text = get_random_prompt() if args.random_prompts else metadata["text"]
#text = get_random_prompt() if i >= (num // 2) else metadata["text"]
language = metadata["language"].lower()
prompt = dir / "prompt.wav"

View File

@ -496,7 +496,7 @@ def example_usage():
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
#available_tasks = cfg.dataset.tasks_list
available_tasks = ["tts", "stt"]
available_tasks = ["tts"] # , "stt"]
model = AR_NAR(**kwargs).to(device)
steps = 500 # 150 * len(available_tasks) # * cfg.model.experimental.causal_size

View File

@ -30,7 +30,7 @@ except Exception as e:
pass
try:
from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
AVAILABLE_ARCHES.append("llama")
except Exception as e:
ERROR_ARCHES["llama"] = e

View File

@ -3,19 +3,24 @@
import math
import torch
import logging
import random
from typing import Literal, overload, Optional, Tuple
from typing import Literal, overload, Optional, Tuple, Union, List, Unpack
from torch import Tensor, nn
from transformers.cache_utils import Cache
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
_logger = logging.getLogger(__name__)
AVAILABLE_ATTENTIONS = []
LN_2 = 0.69314718056
try:
from transformers.utils import is_flash_attn_2_available
@ -131,11 +136,7 @@ if AVAILABLE_ATTENTIONS:
class LlamaAttention_Adapted(LlamaAttention):
def __init__(self, *args, **kwargs):
if 'mode' in kwargs:
self.mode = kwargs['mode']
kwargs.pop("mode")
else:
self.mode = "sdpa"
self.mode = kwargs.pop("mode", "sdpa")
if self.mode == "math":
self.mode = torch.nn.attention.SDPBackend.MATH
@ -301,4 +302,160 @@ class LlamaAttention_Adapted(LlamaAttention):
attn_output = self.o_proj(attn_output)
return attn_output, attn_scores, past_key_value
return attn_output, attn_scores, past_key_value
class LlamaModel_Adapted(LlamaModel):
def __init__(self, *args, **kwargs):
self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
super().__init__(*args, **kwargs)
self.layers_n = len(self.layers)
def dropoff_layer( self, l ):
if not self.training:
return False
# this could probably be a LUT but I'm not fiending for aggressive mal-optimizations
D = math.exp((l * LN_2) / (self.layers_n - 1)) - 1
P = D * self.layer_dropout_p
return random.random() < P
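For reference, the curve dropoff_layer() implements works out to p(l) = layer_dropout_p * (2^(l/(L-1)) - 1): the first layer is never skipped and the last layer is skipped with probability layer_dropout_p. A minimal standalone sketch of the curve, assuming a 12-layer model and the default layer_dropout_p:

import math

LN_2 = 0.69314718056
layers_n, layer_dropout_p = 12, 0.1   # assumed model depth; default p from this commit

for l in range(layers_n):
    D = math.exp((l * LN_2) / (layers_n - 1)) - 1   # == 2 ** (l / (layers_n - 1)) - 1
    P = D * layer_dropout_p                          # probability of skipping layer l this step
    print(f"layer {l:2d}: p(skip) = {P:.3f}")
# layer 0 -> 0.000, layer 11 -> 0.100: shallow layers are almost always kept, while
# deeper layers are dropped more often, making the model more robust to early exiting.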
# to-do: properly implement this per the paper
# this is probably a function of layer number and training step that decides which layer to apply layerskip to during training
def curriculum( self, l, t=0 ):
return 1 # self.layers_n - 1
def early_exit_loss( self, losses, t=0 ):
return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ])
def normalized_per_layer_loss_scale( self, l, t=0 ):
return (self.curriculum(l, t) * self.early_exit_factor( l )) / (sum([ self.curriculum(i, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ]))
def early_exit_factor( self, l ):
if 0 <= l and l < self.layers_n - 1:
return self.early_exit_scale * sum([ i for i in range(0, l) ])
return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
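To make the early-exit weighting concrete, here is a minimal standalone sketch mirroring normalized_per_layer_loss_scale() and early_exit_factor() above, with the curriculum held at 1; the 4-layer depth is an assumption and early_exit_scale is the default from this commit:

layers_n, early_exit_scale = 4, 0.1   # assumed depth; default scale

def early_exit_factor(l):
    # mirrors the method above: every layer but the last gets a small, growing factor,
    # while the last layer gets the dominant layers_n - 1 term
    if 0 <= l < layers_n - 1:
        return early_exit_scale * sum(range(0, l))
    return layers_n - 1 + early_exit_scale * sum(range(0, layers_n - 1))

total = sum(early_exit_factor(i) for i in range(layers_n))
weights = [early_exit_factor(l) / total for l in range(layers_n)]
print(weights)   # ~[0.00, 0.00, 0.03, 0.97]: the final layer still dominates the loss,
                 # earlier exits only contribute a light auxiliary signal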
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
_logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
_logger.warning(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for l, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
if not self.dropoff_layer( l ):
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
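A rough usage sketch for the adapted class, not taken from the repo: the import path, config sizes, and hyperparameter values are all assumptions, but it shows the two kwargs the constructor pops and where the per-layer hidden states come back out:

import torch
from transformers import LlamaConfig
from vall_e.models.arch.llama import LlamaModel_Adapted   # assumed module path

config = LlamaConfig(vocab_size=256, hidden_size=64, num_hidden_layers=4,
                     num_attention_heads=4, intermediate_size=128)
model = LlamaModel_Adapted(config, layer_dropout_p=0.1, early_exit_scale=0.2)
model.train()   # dropoff_layer() only ever skips layers while training

out = model(inputs_embeds=torch.randn(1, 8, config.hidden_size), output_hidden_states=True)
print(len(out.hidden_states))   # 5: the embedding output plus one entry per decoder layer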

View File

@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding
# yuck, kind of needed
from ..data import get_task_symmap
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions'])
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
LossStats = namedtuple('LossStats', ['loss', 'stats'])
"""
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
@ -442,6 +443,7 @@ class Base(nn.Module):
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False
layerskip = self.config.experimental.layerskip if self.config is not None else False
n_tasks = self.config.tasks if self.config is not None else 8
n_langs = self.config.langs if self.config is not None else 2
@ -469,6 +471,7 @@ class Base(nn.Module):
self.unified_position_ids = unified_position_ids
self.interleave = interleave
self.layerskip = layerskip
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@ -601,8 +604,9 @@ class Base(nn.Module):
use_reentrant=False
))
elif self.arch_type == "llama":
LlamaClass = LlamaModel_Adapted if self.layerskip else LlamaModel
if n_experts <= 1:
self.model = LlamaModel(LlamaConfig(
self.model = LlamaClass(LlamaConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
@ -814,12 +818,14 @@ class Base(nn.Module):
state = None,
output_attentions = False,
output_hidden_states = False,
):
x = inputs
m = mask.squeeze(-1).int()
aux_loss = None
attentions = None
hidden_states = None
# HF transformer derived model
if self.arch_type in ["llama", "mistral", "mixtral"]:
@ -830,6 +836,7 @@ class Base(nn.Module):
position_ids=position_ids,
use_cache=not self.training,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
if self.n_experts > 1 and self.training:
@ -846,6 +853,9 @@ class Base(nn.Module):
if output_attentions:
attentions = output["attentions"]
if output_hidden_states:
hidden_states = output["hidden_states"]
if self.n_experts > 1 and self.training:
router_logits = output["aux_loss"]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
@ -904,11 +914,19 @@ class Base(nn.Module):
x = x[0]
# process it into a format that I like
if output_hidden_states:
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i in range( self.n_layers ) ]
# output projection layer with masking
if self.classifier is not None:
x = self.classifier(x) * mask
if output_hidden_states:
for i in range( self.n_layers ):
hidden_states[i] = self.classifier(hidden_states[i]) * m
return Logits(x, state, aux_loss, attentions)
return Logits(x, state, aux_loss, attentions, hidden_states)
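What this amounts to is the usual early-exit head sharing: every intermediate hidden state gets the final norm and the same output head applied, so each layer yields its own logits. A shape-only sketch, with LayerNorm standing in for the model's RMSNorm and all sizes made up:

import torch
from torch import nn

d_model, vocab, n_layers, seq = 64, 100, 4, 8
final_norm = nn.LayerNorm(d_model)        # stand-in for self.model.norm
classifier = nn.Linear(d_model, vocab)    # stand-in for the shared self.classifier

hidden_states = [torch.randn(1, seq, d_model) for _ in range(n_layers)]
per_layer_logits = [classifier(final_norm(h)) for h in hidden_states]
print(per_layer_logits[0].shape)   # torch.Size([1, 8, 100]) at every exit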
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs(
@ -1217,6 +1235,9 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None,
):
loss = dict(ce = dict())
stats = dict(acc = dict())
device = logits[0].device
special_tasks = [ "len", "stt" ]
summed_embeddings_task = [ "stt" ]
@ -1285,23 +1306,22 @@ class Base(nn.Module):
if False:
target = torch.cat( target_list )
inputs = torch.cat( logits )
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
loss = dict(
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
self.stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
else:
self.loss = dict(
loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
self.stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
)
return
return LossStats(loss, stats)
"""
# considerations:
@ -1311,9 +1331,6 @@ class Base(nn.Module):
# + extra logic might be required to instead offset from the end for the resp, rather than fit snugly
# + this might just be a spook since the odds of the very first token of the AR mattering are slim (although I swear I hear a very brief audio pop sometimes)
"""
self.loss = dict()
self.stats = dict(acc = dict())
info = {}
batch_size = len( inputs )
@ -1385,17 +1402,19 @@ class Base(nn.Module):
if False:
targets = torch.cat( batch["targets"] ).long()
inputs = torch.cat( batch["logits"] )
self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
stats["acc"][name] = self.accuracy_metric( inputs, targets )
# probably consumes less memory since it avoids allocating one large concatenated tensor
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else:
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
if self.metrics is not None:
metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
self.stats["acc"][name] = metrics["acc"]
stats["acc"][name] = metrics["acc"]
else:
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
return LossStats(loss, stats)
def forward(
self,
@ -1404,6 +1423,7 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None,
state: dict | list | None = None,
output_attentions = False,
output_hidden_states = False,
):
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list)
@ -1412,10 +1432,12 @@ class Base(nn.Module):
device = x.device
batch_size = len(x_list)
# pure AR
if quant_levels is None:
quant_levels = [ 0 for _ in range(batch_size) ]
if self.layerskip:
output_hidden_states = True
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
@ -1440,9 +1462,11 @@ class Base(nn.Module):
state=state,
position_ids=position_ids,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
)
logits = output.logits
hidden_states = output.hidden_states
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
@ -1451,19 +1475,52 @@ class Base(nn.Module):
classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
if hidden_states is not None:
for i in range( self.n_layers ):
hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m
# Remove padding
logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
# compute loss if the target is given
if training:
self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
if not self.layerskip:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
else:
loss = {}
stats = {}
for i in range( self.n_layers ):
# remove padding
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
layer_loss, layer_stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )
for k, v in layer_loss.items():
if k not in loss:
loss[k] = []
loss[k].append( v )
for k, v in layer_stats.items():
if k not in stats:
stats[k] = []
stats[k].append( v )
for k, v in loss.items():
loss[k] = self.model.early_exit_loss( losses=v )
for k, v in stats.items():
stats[k] = sum( v ) / len( v )
# include any additional losses (for example: MoE router)
if output.aux_loss is not None:
self.loss["aux_loss"] = output.aux_loss
loss["aux_loss"] = output.aux_loss
self.loss = loss
self.stats = stats
# rewrap, because we're modifying the logits here
return Logits(logits, output.state, output.aux_loss, output.attentions)
return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states)
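For clarity, the layerskip training branch above reduces to: run calc_loss once per layer's classifier-projected hidden states, collect the values per key, fold the losses with the early-exit weights, and plainly average the stats. A toy sketch of just that aggregation; the weights, keys, and numbers are made up:

weights = [0.00, 0.00, 0.03, 0.97]                  # assumed output of normalized_per_layer_loss_scale()
per_layer_loss  = {"nll": [5.2, 4.9, 4.1, 3.0]}     # hypothetical CE measured at each layer's exit
per_layer_stats = {"acc": [0.10, 0.15, 0.22, 0.31]} # hypothetical accuracy at each exit

loss  = {k: sum(w * x for w, x in zip(weights, v)) for k, v in per_layer_loss.items()}
stats = {k: sum(v) / len(v) for k, v in per_layer_stats.items()}
print(loss, stats)   # {'nll': ~3.03} {'acc': ~0.20}: dominated by the last layer,
                     # lightly regularized by the earlier exits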
def sample(
self,

View File

@ -60,7 +60,7 @@ class NAR(Base):
# is training
if resps_list is not None:
p_len_task = self.config.experimental.p_len_train if self.config is not None else 0.05
len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
@ -69,7 +69,7 @@ class NAR(Base):
# to-do: make this YAML configurable
def sample_task():
return "len" if random.random() < p_len_task else "tts"
return "len" if random.random() < len_train_p else "tts"
# generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ]
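A quick standalone sanity check of the renamed knob (everything here is illustrative): with len_train_p = 0.05, roughly one in twenty sampled tasks should come out as "len".

import random

len_train_p = 0.05

def sample_task():
    return "len" if random.random() < len_train_p else "tts"

task_list = [sample_task() for _ in range(10_000)]
print(task_list.count("len") / len(task_list))   # ~0.05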