third time's the charm (for some reason it escaped me that I should treat early exit loss as an aux_loss to be used with the normal loss, as if I was training a MoE's router)
parent 76ebef45dc
commit 9b6c57bc57
@@ -260,7 +260,7 @@ class ModelExperimentalSettings:
     layerskip: bool = False # layerskip compatible model (or training for)
     #layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
-    layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
+    layerskip_r: int = 6 # number of layers to factor into early-exit loss calc
     layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
     layerskip_e_scale: float = 0.2 # early-exit loss scalar value
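For context on the knobs above: `layerskip_r` bounds how many exits feed the early-exit loss, `layerskip_p_max` caps how aggressively later layers are dropped, and `layerskip_e_scale` scales the early-exit term. None of the actual schedules appear in this diff, so the sketch below is only one plausible mapping (linear ramps, a hypothetical `layerskip_schedules` helper), loosely in the spirit of LayerSkip-style training, and it does not model the `layerskip_r` cutoff.

# Hypothetical sketch (not from the repo): one plausible way the config knobs
# could translate into per-layer schedules. The linear ramps are assumptions.
def layerskip_schedules(n_layers: int, p_max: float = 0.1, e_scale: float = 0.2):
    # layer dropout: shallow layers are almost never dropped, the last layer
    # is dropped with probability p_max
    dropout_p = [ p_max * l / max(n_layers - 1, 1) for l in range(n_layers) ]
    # early-exit loss weights: deeper exits count more, and the whole term is
    # scaled by e_scale so it stays an auxiliary contribution
    raw = [ float(l + 1) for l in range(n_layers) ]
    exit_scale = [ e_scale * w / sum(raw) for w in raw ]
    return dropout_p, exit_scale

dropout_p, exit_scale = layerskip_schedules(n_layers=12)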
@@ -43,7 +43,7 @@ class AR_NAR(Base):
         tone_list: list[Tensor] | None = None,
         len_list: list[Tensor] | None = None,
 
-        training: bool | None = None,
+        training: bool | int | None = None,
 
         max_steps: int = 1000,
         max_levels: int = 0,
@@ -97,11 +97,12 @@ class AR_NAR(Base):
         n_levels_set = {r.shape[-1] for r in resps_list}
         n_levels = next(iter(n_levels_set))
 
+        # implicit
         if training is None:
-            training = n_levels == self.n_resp_levels
+            training = 0 if n_levels == self.n_resp_levels else None
 
         # is training
-        if training:
+        if training is not None:
             # specifies how to sample probabilities of which RVQ levels to train against
             rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
             # determines which RVQ level to target per batch
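The signature and check changes above go together: `training` can now be an int (it is set to `0` when all RVQ levels are present), and `0` is falsy, so a bare `if training:` would silently skip the training path. A minimal illustration:

# training is now tri-state: None = inference, int (starting at 0) = training
training = 0
assert not training          # 0 is falsy, so `if training:` would not fire
assert training is not None  # the new check still recognizes it as training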
@@ -371,7 +371,7 @@ class LlamaModel_Adapted(LlamaModel):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
+            _logger.warning_once(
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
             )
             use_cache = False
@@ -387,7 +387,7 @@ class LlamaModel_Adapted(LlamaModel):
                 past_key_values = DynamicCache()
             else:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                logger.warning_once(
+                _logger.warning_once(
                     "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                     "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                     "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
@@ -476,7 +476,6 @@ class Base(nn.Module):
         self.unified_position_ids = unified_position_ids
         self.interleave = interleave
         self.layerskip = layerskip
-        self.training_step = 0
 
         self.text_emb = Embedding(n_text_tokens, d_model)
         self.langs_emb = None
@@ -927,14 +926,17 @@ class Base(nn.Module):
         # process it into a format that I like
         if output_hidden_states:
             # hidden_states is actually layers + 1, as hidden_states[0] == embedding...........
-            hidden_states = [ x if i == self.n_layers else self.model.norm(output.hidden_states[i]) for i in range( 1, self.n_layers + 1 ) ]
+            hidden_states = [ state for state in hidden_states[1:] ]
+            # apply normalization to these states (to-do: check if this matters)
+            # but skip the last state, as it already is normalized
+            hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
 
         # output projection layer with masking
         if self.classifier is not None:
             x = self.classifier(x) * mask
 
             if output.hidden_states:
-                for i in range( self.n_layers ):
+                for i, state in enumerate( hidden_states ):
                     hidden_states[i] = self.classifier(hidden_states[i]) * m
 
         return Logits(x, state, aux_loss, attentions, hidden_states)
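The reworked block drops the embedding output (`hidden_states[0]`), applies the final `model.norm` to every intermediate state, and leaves the last state untouched since the model already normalized it. A standalone sketch of that step with illustrative names (`normalize_hidden_states` is not part of the repo):

import torch

def normalize_hidden_states(hidden_states: list[torch.Tensor], norm: torch.nn.Module):
    # drop the embedding output, then normalize every state except the last,
    # which already passed through the model's final norm
    states = hidden_states[1:]
    return [ state if i == len(states) - 1 else norm(state)
             for i, state in enumerate(states) ]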
@@ -1325,6 +1327,7 @@ class Base(nn.Module):
                 # precision = self.precision_metric( inputs, target ),
             )
         else:
+            # nll being natural log likelihood :)))) (I don't know why this completely escaped be originally with thinking it meant something else)
             loss = dict(
                 nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
             )
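In PyTorch terms the `nll` key tracked here is the negative log-likelihood of the softmax distribution, which is exactly what `F.cross_entropy` computes; a quick check of that identity:

import torch
import torch.nn.functional as F

logits  = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
# cross_entropy == nll_loss applied to log-softmax'd logits
a = F.cross_entropy(logits, targets)
b = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
assert torch.allclose(a, b)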
@@ -1487,42 +1490,41 @@ class Base(nn.Module):
             logits = self.classifiers(logits, levels = classifier_quant_levels) * m
 
             if hidden_states is not None:
-                for i in range( self.n_layers ):
+                for i, state in enumerate( hidden_states ):
                     hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m
 
         # Remove padding
         logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
 
-        # compute loss if the target is given
-        if training:
-            if not self.layerskip:
-                loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
-            else:
-                self.loss = {}
-                self.stats = {}
-                for i in range( self.n_layers ):
-                    # remove padding
-                    hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
-
-                    loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )
-
-                    for k, v in loss.items():
-                        if k not in self.loss:
-                            self.loss[k] = []
-                        self.loss[k].append( v )
-
-                    for k, v in stats.items():
-                        if k not in self.stats:
-                            self.stats[k] = []
-                        self.stats[k].append( v )
-
-                for k, v in self.loss.items():
-                    self.loss[k] = self.model.early_exit_loss( losses=v, t=self.training_step )
-                # ick
-                self.training_step += 1
-
-                for k, v in self.stats.items():
-                    self.stats[k] = sum( v ) / len( v )
+        if hidden_states is not None:
+            for i, state in enumerate( hidden_states ):
+                # remove padding
+                hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
+
+        # compute loss if the target is given
+        if training:
+            loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+
+            # compute it as an aux-loss
+            if self.layerskip:
+                early_exit_loss = {}
+                if not hasattr( self, "training_steps" ):
+                    self.training_steps = 0
+
+                for i, state in enumerate( hidden_states ):
+                    loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )
+
+                    for k, v in loss.items():
+                        K = f'early_exit.{k}'
+                        if K not in early_exit_loss:
+                            early_exit_loss[K] = []
+                        early_exit_loss[K].append( v )
+
+                for k, v in early_exit_loss.items():
+                    loss[k] = self.model.early_exit_loss( losses=v, t=self.training_steps )
+
+                # to-do: instead make the cirriculum rely on samples processed instead of steps
+                self.training_steps += 1 # batch_size
 
         # include any additional losses (for example: MoE router)
         if output.aux_loss is not None:
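The net effect of the last hunk: per-exit losses are collected under `early_exit.*` keys, reduced by `self.model.early_exit_loss(losses, t)`, and carried in the same loss dict as `nll`, so they combine with the main loss the same way a MoE router's aux_loss does. That reduction function is not shown in this diff, so the sketch below is a hypothetical stand-in (deeper exits weighted more, contribution ramped in over a `warmup_steps` assumption) just to show the shape of the aggregation.

import torch

def early_exit_loss(losses: list[torch.Tensor], t: int,
                    e_scale: float = 0.2, warmup_steps: int = 10_000) -> torch.Tensor:
    # weight deeper exits more heavily, normalized to sum to 1
    weights = torch.tensor([ float(i + 1) for i in range(len(losses)) ])
    weights = weights / weights.sum()
    # curriculum: ramp the early-exit contribution in over the first warmup_steps
    ramp = min(t / warmup_steps, 1.0)
    return e_scale * ramp * sum( w * l for w, l in zip(weights, losses) )

# with the early_exit.* entries sitting next to `nll` in one dict, the trainer
# can fold them in by summing the dict's values, as it would for any aux loss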