third time's the charm (for some reason it escaped me that I should treat early exit loss as an aux_loss to be used with the normal loss, as if I was training a MoE's router)

mrq 2024-11-01 12:50:37 -05:00
parent 76ebef45dc
commit 9b6c57bc57
4 changed files with 35 additions and 32 deletions
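The idea in the commit message, as a rough self-contained sketch: every early-exit layer gets its own cross-entropy against the same targets, and those per-layer losses ride along in the loss dict next to the normal last-layer loss, the same way a MoE router's aux loss does. The names below (combined_loss, e_scale) are illustrative, not this repo's calc_loss / early_exit_loss API.

    import torch
    import torch.nn.functional as F

    def combined_loss(final_logits, layer_logits, targets, e_scale=0.2, ignore_index=-100):
        # normal loss from the last layer's (projected) output
        loss = { "nll": F.cross_entropy(final_logits, targets, ignore_index=ignore_index) }
        # aux loss: average the early-exit losses and scale them down so the
        # intermediate exits get a training signal without dominating the objective
        early = [ F.cross_entropy(logits, targets, ignore_index=ignore_index) for logits in layer_logits ]
        loss["early_exit.nll"] = e_scale * torch.stack(early).mean()
        # the trainer then just sums every entry in the dict, MoE-router style
        return loss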

View File

@@ -260,7 +260,7 @@ class ModelExperimentalSettings:
layerskip: bool = False # layerskip compatible model (or training for)
#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
-layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
+layerskip_r: int = 6 # number of layers to factor into early-exit loss calc
layerskip_p_max: float = 0.1 # maximum probability to dropout the last layer, used for calculating layer dropout probabilities
layerskip_e_scale: float = 0.2 # early-exit loss scalar value
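For reference, one plausible reading of these three knobs, loosely after the LayerSkip recipe (illustrative formulas, not necessarily what this repo computes): dropout probability ramps up with depth to layerskip_p_max at the last layer, only every layerskip_r-th layer contributes an early-exit loss on a given step, and each contribution is scaled by layerskip_e_scale.

    def layer_dropout_p(layer_idx, n_layers, p_max=0.1):
        # deeper layers get skipped more often, capping at p_max for the last layer
        return p_max * layer_idx / max(n_layers - 1, 1)

    def early_exit_layers(step, n_layers, r=6):
        # rotational curriculum: only every r-th layer pays early-exit loss this step,
        # with the offset rotating as training progresses
        return [ l for l in range(n_layers) if (l + step) % r == 0 ]

    def early_exit_weight(layer_idx, n_layers, e_scale=0.2):
        # scale each exit's loss, favoring later (more capable) exits
        return e_scale * (layer_idx + 1) / n_layers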

View File

@@ -43,7 +43,7 @@ class AR_NAR(Base):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
-training: bool | None = None,
+training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
@@ -97,11 +97,12 @@ class AR_NAR(Base):
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
# implicit
if training is None:
-training = n_levels == self.n_resp_levels
+training = 0 if n_levels == self.n_resp_levels else None
# is training
-if training:
+if training is not None:
# specifies how to sample probabilities of which RVQ levels to train against
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
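The widened training: bool | int | None hint (and the is not None check) matter because training now carries a step counter instead of a bare flag: None means inference, while 0 means "training, step zero", which a plain truthiness check would silently treat as False. A minimal sketch of that dispatch, under the assumption that the int is a step count:

    def resolve_training(training, n_levels, n_resp_levels):
        # assumed semantics: None = inference, int = training step counter
        if training is None:
            # implicit: a batch carrying all RVQ levels implies training, starting at step 0
            training = 0 if n_levels == n_resp_levels else None
        return training

    assert resolve_training(None, n_levels=8, n_resp_levels=8) is not None  # step 0 still counts as training
    assert resolve_training(None, n_levels=1, n_resp_levels=8) is None      # inference path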

View File

@@ -371,7 +371,7 @@ class LlamaModel_Adapted(LlamaModel):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
-logger.warning_once(
+_logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
@@ -387,7 +387,7 @@ class LlamaModel_Adapted(LlamaModel):
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-logger.warning_once(
+_logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"

View File

@@ -476,7 +476,6 @@ class Base(nn.Module):
self.unified_position_ids = unified_position_ids
self.interleave = interleave
self.layerskip = layerskip
-self.training_step = 0
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@@ -927,14 +926,17 @@ class Base(nn.Module):
# process it into a format that I like
if output_hidden_states:
# hidden_states is actually layers + 1, as hidden_states[0] == embedding...........
-hidden_states = [ x if i == self.n_layers else self.model.norm(output.hidden_states[i]) for i in range( 1, self.n_layers + 1 ) ]
+hidden_states = [ state for state in hidden_states[1:] ]
+# apply normalization to these states (to-do: check if this matters)
+# but skip the last state, as it already is normalized
+hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
# output projection layer with masking
if self.classifier is not None:
x = self.classifier(x) * mask
if output.hidden_states:
-for i in range( self.n_layers ):
+for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifier(hidden_states[i]) * m
return Logits(x, state, aux_loss, attentions, hidden_states)
@@ -1325,6 +1327,7 @@ class Base(nn.Module):
# precision = self.precision_metric( inputs, target ),
)
else:
+# nll being negative log likelihood :)))) (I don't know why this completely escaped me originally, thinking it meant something else)
loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
@@ -1487,42 +1490,41 @@ class Base(nn.Module):
logits = self.classifiers(logits, levels = classifier_quant_levels) * m
if hidden_states is not None:
-for i in range( self.n_layers ):
+for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m
# Remove padding
logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
+if hidden_states is not None:
+for i, state in enumerate( hidden_states ):
+# remove padding
+hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
# compute loss if the target is given
if training:
-if not self.layerskip:
-loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
-else:
-self.loss = {}
-self.stats = {}
+loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+# compute it as an aux-loss
+if self.layerskip:
+early_exit_loss = {}
+if not hasattr( self, "training_steps" ):
+self.training_steps = 0
-for i in range( self.n_layers ):
-# remove padding
-hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
+for i, state in enumerate( hidden_states ):
loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )
for k, v in loss.items():
-if k not in self.loss:
-self.loss[k] = []
-self.loss[k].append( v )
+K = f'early_exit.{k}'
+if K not in early_exit_loss:
+early_exit_loss[K] = []
+early_exit_loss[K].append( v )
-for k, v in stats.items():
-if k not in self.stats:
-self.stats[k] = []
-self.stats[k].append( v )
+for k, v in early_exit_loss.items():
+loss[k] = self.model.early_exit_loss( losses=v, t=self.training_steps )
-for k, v in self.loss.items():
-self.loss[k] = self.model.early_exit_loss( losses=v, t=self.training_step )
-# ick
-self.training_step += 1
-for k, v in self.stats.items():
-self.stats[k] = sum( v ) / len( v )
+# to-do: instead make the curriculum rely on samples processed instead of steps
+self.training_steps += 1 # batch_size
# include any additional losses (for example: MoE router)
if output.aux_loss is not None:
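The self.model.early_exit_loss( losses=v, t=... ) call above reduces the list of per-layer losses to a single value, apparently with a step-dependent curriculum (hence the to-do about counting samples instead of steps). A hedged sketch of one such reduction, assuming a linear ramp-in and depth-weighted averaging; the repo's actual weighting may differ.

    import torch

    def early_exit_loss(losses, t, e_scale=0.2, warmup=10_000):
        # weight later exits more heavily, normalized so the weights sum to 1
        weights = torch.tensor([ float(i + 1) for i in range(len(losses)) ])
        weights = weights / weights.sum()
        # ramp the aux term in from 0 to 1 over `warmup` steps
        curriculum = min(t / warmup, 1.0)
        return e_scale * curriculum * sum( w * l for w, l in zip(weights, losses) )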