third time's the charm (for some reason it escaped me that I should treat early exit loss as an aux_loss to be used with the normal loss, as if I was training a MoE's router)
This commit is contained in:
parent 76ebef45dc
commit 9b6c57bc57
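The change below folds the per-layer early-exit losses into the normal loss as auxiliary terms, the same way a MoE router's balancing loss is added on top of the main objective. A minimal sketch of that idea, assuming a shared target and a flat scale; `total_loss`, `per_layer_logits`, and `early_exit_scale` are illustrative names, not the repo's:

```python
# Minimal sketch (not the repo's code): fold per-layer early-exit losses into
# the normal loss as an auxiliary term, like a MoE router's balancing loss.
import torch
import torch.nn.functional as F

def total_loss(final_logits, per_layer_logits, targets, early_exit_scale=0.2):
    # normal loss from the final layer's logits
    loss = F.cross_entropy(final_logits, targets)
    # auxiliary early-exit loss: average the losses computed from each
    # intermediate layer's logits (uniform weighting here; the real
    # curriculum weights them per layer / training step)
    aux = torch.stack([ F.cross_entropy(l, targets) for l in per_layer_logits ]).mean()
    return loss + early_exit_scale * aux
```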
@@ -260,7 +260,7 @@ class ModelExperimentalSettings:
     layerskip: bool = False # layerskip compatible model (or training for)
     #layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
-    layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
+    layerskip_r: int = 6 # number of layers to factor into early-exit loss calc
     layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
     layerskip_e_scale: float = 0.2 # early-exit loss scalar value
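For context, a rough sketch of how these knobs could be consumed; the linear depth ramp and last-`r`-layers weighting are assumptions for illustration, not the repo's exact formulas:

```python
# Illustrative only (assumed formulas, not the repo's): dropout probability
# ramps linearly with depth toward layerskip_p_max, and only the last
# layerskip_r layers contribute an early-exit loss scaled by layerskip_e_scale.
def layer_dropout_p(layer: int, n_layers: int, p_max: float = 0.1) -> float:
    # first layer is never dropped, the last layer is dropped with p_max
    return p_max * layer / max(n_layers - 1, 1)

def early_exit_weight(layer: int, n_layers: int, r: int = 6, e_scale: float = 0.2) -> float:
    # only the last r layers factor into the early-exit loss
    return e_scale if layer >= n_layers - r else 0.0
```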
@@ -43,7 +43,7 @@ class AR_NAR(Base):
         tone_list: list[Tensor] | None = None,
         len_list: list[Tensor] | None = None,

-        training: bool | None = None,
+        training: bool | int | None = None,

         max_steps: int = 1000,
         max_levels: int = 0,
@@ -97,11 +97,12 @@ class AR_NAR(Base):
         n_levels_set = {r.shape[-1] for r in resps_list}
         n_levels = next(iter(n_levels_set))

         # implicit
         if training is None:
-            training = n_levels == self.n_resp_levels
+            training = 0 if n_levels == self.n_resp_levels else None

+        # is training
-        if training:
+        if training is not None:
             # specifies how to sample probabilities of which RVQ levels to train against
             rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
             # determines which RVQ level to target per batch
@@ -371,7 +371,7 @@ class LlamaModel_Adapted(LlamaModel):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
+            _logger.warning_once(
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
             )
             use_cache = False
@@ -387,7 +387,7 @@ class LlamaModel_Adapted(LlamaModel):
                 past_key_values = DynamicCache()
             else:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                logger.warning_once(
+                _logger.warning_once(
                     "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                     "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                     "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
@@ -476,7 +476,6 @@ class Base(nn.Module):
         self.unified_position_ids = unified_position_ids
         self.interleave = interleave
         self.layerskip = layerskip
-        self.training_step = 0

         self.text_emb = Embedding(n_text_tokens, d_model)
         self.langs_emb = None
@@ -927,14 +926,17 @@ class Base(nn.Module):
         # process it into a format that I like
         if output_hidden_states:
             # hidden_states is actually layers + 1, as hidden_states[0] == embedding...........
-            hidden_states = [ x if i == self.n_layers else self.model.norm(output.hidden_states[i]) for i in range( 1, self.n_layers + 1 ) ]
+            hidden_states = [ state for state in hidden_states[1:] ]
+            # apply normalization to these states (to-do: check if this matters)
+            # but skip the last state, as it already is normalized
+            hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]

         # output projection layer with masking
         if self.classifier is not None:
             x = self.classifier(x) * mask

             if output.hidden_states:
-                for i in range( self.n_layers ):
+                for i, state in enumerate( hidden_states ):
                     hidden_states[i] = self.classifier(hidden_states[i]) * m

         return Logits(x, state, aux_loss, attentions, hidden_states)
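In effect, every intermediate hidden state is run through the model's final norm (the last one already is normalized) and then through the shared output head, so each layer yields its own early-exit logits. A hedged, self-contained sketch of that shape, where `norm` and `head` stand in for `self.model.norm` and the classifier:

```python
# Hedged sketch of the early-exit head path (names are illustrative):
# normalize every intermediate hidden state with the model's final norm
# (skip the last, which is already normalized), then project each through
# the shared output head to get per-layer logits.
import torch
from torch import nn

def early_exit_logits(hidden_states: list[torch.Tensor], norm: nn.Module, head: nn.Module) -> list[torch.Tensor]:
    out = []
    for i, state in enumerate(hidden_states):
        if i != len(hidden_states) - 1:
            state = norm(state)  # intermediate states still need the final norm
        out.append(head(state))  # shared classifier / LM head
    return out
```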
@@ -1325,6 +1327,7 @@ class Base(nn.Module):
                 # precision = self.precision_metric( inputs, target ),
             )
         else:
+            # nll being natural log likelihood :)))) (I don't know why this completely escaped be originally with thinking it meant something else)
             loss = dict(
                 nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
             )
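For the record, the `nll` term here is plain cross-entropy, i.e. the mean negative log-likelihood under the softmax; a quick sanity check in stock PyTorch (not repo code):

```python
# Sanity check with plain PyTorch (not repo code): F.cross_entropy is the
# mean negative log-likelihood of the softmax distribution over the targets.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # (batch, classes)
targets = torch.randint(0, 10, (4,))  # (batch,)
nll_a = F.cross_entropy(logits, targets)
nll_b = -F.log_softmax(logits, dim=-1).gather(1, targets[:, None]).mean()
assert torch.allclose(nll_a, nll_b)
```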
@@ -1487,42 +1490,41 @@ class Base(nn.Module):
             logits = self.classifiers(logits, levels = classifier_quant_levels) * m

             if hidden_states is not None:
-                for i in range( self.n_layers ):
+                for i, state in enumerate( hidden_states ):
                     hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m

         # Remove padding
         logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
+
+        if hidden_states is not None:
+            for i, state in enumerate( hidden_states ):
+                # remove padding
+                hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]

         # compute loss if the target is given
         if training:
-            if not self.layerskip:
-                loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
-            else:
-                self.loss = {}
-                self.stats = {}
+            loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+
+            # compute it as an aux-loss
+            if self.layerskip:
+                early_exit_loss = {}
+                if not hasattr( self, "training_steps" ):
+                    self.training_steps = 0

-                for i in range( self.n_layers ):
-                    # remove padding
-                    hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
+                for i, state in enumerate( hidden_states ):
                     loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )

                     for k, v in loss.items():
-                        if k not in self.loss:
-                            self.loss[k] = []
-                        self.loss[k].append( v )
+                        K = f'early_exit.{k}'
+                        if K not in early_exit_loss:
+                            early_exit_loss[K] = []
+                        early_exit_loss[K].append( v )

-                    for k, v in stats.items():
-                        if k not in self.stats:
-                            self.stats[k] = []
-                        self.stats[k].append( v )
+                for k, v in early_exit_loss.items():
+                    loss[k] = self.model.early_exit_loss( losses=v, t=self.training_steps )

-                for k, v in self.loss.items():
-                    self.loss[k] = self.model.early_exit_loss( losses=v, t=self.training_step )
-                # ick
-                self.training_step += 1
-
-                for k, v in self.stats.items():
-                    self.stats[k] = sum( v ) / len( v )
+                # to-do: instead make the cirriculum rely on samples processed instead of steps
+                self.training_steps += 1 # batch_size

         # include any additional losses (for example: MoE router)
         if output.aux_loss is not None:
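The `early_exit_loss( losses=..., t=... )` helper on the adapted model reduces the per-layer losses with a step-based curriculum (see the to-do above). A purely illustrative sketch of what such a reduction might look like; `warmup_steps` and the depth weighting are assumptions, and the real implementation on LlamaModel_Adapted may differ:

```python
# Purely illustrative (the real early_exit_loss lives on LlamaModel_Adapted and
# its weighting/curriculum may differ): weight deeper layers more heavily and
# ramp the auxiliary term in over an assumed number of warmup steps.
import torch

def early_exit_loss(losses: list[torch.Tensor], t: int, warmup_steps: int = 10_000, e_scale: float = 0.2) -> torch.Tensor:
    n = len(losses)
    # deeper layers get proportionally larger weights, normalized to sum to 1
    weights = [ (i + 1) / n for i in range(n) ]
    total = sum(weights)
    weights = [ w / total for w in weights ]
    # curriculum: scale the auxiliary loss up over the first warmup_steps
    ramp = min(t / warmup_steps, 1.0)
    return e_scale * ramp * sum(w * l for w, l in zip(weights, losses))
```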