mrq 2024-10-30 22:49:11 -05:00
parent a22534e8f4
commit b63293cbbe
2 changed files with 18 additions and 8 deletions


@@ -322,15 +322,23 @@ class LlamaModel_Adapted(LlamaModel):
 		return random.random() < P
 
 	# to-do: properly implement this per the paper
 	# this is probably a function of the layer number and training step, deciding which layer to apply layerskip to during training
-	def cirriculum( self, l, t=0 ):
-		return 1 # self.layers_n - 1
+	# there doesn't seem to be /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit
+	def cirriculum( self, l, t=None, R=2 ):
+		if t is None:
+			return 1
+		# YUCK
+		# this guarantees at least R layers are active at all intervals, which is important because otherwise this gives a division by zero
+		for i in range(R):
+			if l == ((t % self.layers_n) + i * (self.layers_n // R)) % self.layers_n:
+				return 1
+		return 0
 
-	def early_exit_loss( self, losses, t=0 ):
+	def early_exit_loss( self, losses, t=None ):
 		return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ])
 
-	def normalized_per_layer_loss_scale( self, l, t=0 ):
-		return (self.cirriculum(l, t) * self.early_exit_factor( l )) / (sum([ self.cirriculum(l, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ]))
+	def normalized_per_layer_loss_scale( self, l, t=None ):
+		return (self.cirriculum(l, t) * self.early_exit_factor( l )) / sum([ self.cirriculum(i, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ])
 
 	def early_exit_factor( self, l ):
 		if 0 <= l and l < self.layers_n:
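For reference, a minimal standalone sketch of the rotational curriculum and the loss normalization above. The depth (layers_n = 12) and the exponential early_exit_factor are assumptions for illustration, since the real early_exit_factor() is truncated in the hunk; everything else mirrors the committed logic. It also shows why the denominator fix (iterating cirriculum(i, t) instead of cirriculum(l, t)) matters: the per-layer scales must sum to 1 over all layers.

import math

layers_n = 12  # assumed depth, not a value from this commit

def cirriculum(l, t=None, R=2):
	# no training step given: every layer participates in early exit
	if t is None:
		return 1
	# rotate the active set with t; R evenly spaced layers are always
	# active, so the normalization below can never divide by zero
	for i in range(R):
		if l == ((t % layers_n) + i * (layers_n // R)) % layers_n:
			return 1
	return 0

def early_exit_factor(l, scale=0.2):
	# hypothetical per-layer weight that grows with depth; the actual
	# early_exit_factor() is cut off in the hunk above
	return math.exp(scale * l)

def normalized_per_layer_loss_scale(l, t=None):
	total = sum(cirriculum(i, t) * early_exit_factor(i) for i in range(layers_n))
	return (cirriculum(l, t) * early_exit_factor(l)) / total

# at step t=3 with R=2, only layers 3 and 9 carry the early-exit loss
print([l for l in range(layers_n) if cirriculum(l, t=3)])                     # [3, 9]
print(round(sum(normalized_per_layer_loss_scale(l, t=3) for l in range(layers_n)), 6))  # 1.0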


@@ -472,6 +472,7 @@ class Base(nn.Module):
 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
 		self.layerskip = layerskip
+		self.training_step = 0
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -1506,12 +1507,13 @@ class Base(nn.Module):
 			self.stats[k].append( v )
 
 		for k, v in self.loss.items():
-			self.loss[k] = self.model.early_exit_loss( losses=v )
+			self.loss[k] = self.model.early_exit_loss( losses=v, t=self.training_step )
 
+		# ick
+		self.training_step += 1
 
 		for k, v in self.stats.items():
 			self.stats[k] = sum( v ) / len( v )
 
 		# include any additional losses (for example: MoE router)
 		if output.aux_loss is not None:
 			loss["aux_loss"] = output.aux_loss