ugh
parent a22534e8f4
commit b63293cbbe
@@ -322,15 +322,23 @@ class LlamaModel_Adapted(LlamaModel):
 		return random.random() < P

 	# to-do: properly implement this per the paper
 	# this probably is a function of layer number and training step to decide what layer to apply layerskip to for training
-	def cirriculum( self, l, t=0 ):
-		return 1 # self.layers_n - 1
+	# there doesn't seem to be /too/ bad of a performance hit, but the paper mentions it affecting accuracy of the last layer if all layers had early exit
+	def cirriculum( self, l, t=None, R=2 ):
+		if t is None:
+			return 1
+
+		# YUCK
+		# this guarantees at least R layers are active at all intervals, which is important because this gives a division by zero otherwise
+		for i in range(R):
+			if l == ((t % self.layers_n) + i * (self.layers_n // R)) % self.layers_n:
+				return 1
+		return 0

-	def early_exit_loss( self, losses, t=0 ):
+	def early_exit_loss( self, losses, t=None ):
 		return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ])

-	def normalized_per_layer_loss_scale( self, l, t=0 ):
-		return (self.cirriculum(l, t) * self.early_exit_factor( l )) / (sum([ self.cirriculum(l, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ]))
+	def normalized_per_layer_loss_scale( self, l, t=None ):
+		return (self.cirriculum(l, t) * self.early_exit_factor( l )) / sum([ self.cirriculum(i, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ])

 	def early_exit_factor( self, l ):
 		if 0 <= l and l < self.layers_n:
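For anyone skimming the diff, here is a minimal self-contained sketch of what the new rotational curriculum computes. It is not code from this repo: layers_n = 12, R = 2, and the linearly increasing early_exit_factor stand-in are all assumed for illustration, since the real early_exit_factor body is cut off in this hunk.

	# standalone sketch (toy values, not repo code)
	layers_n = 12

	def cirriculum( l, t=None, R=2 ):
		# same logic as the method above, minus `self`
		if t is None:
			return 1
		# at step t, exactly R layers spaced layers_n // R apart are active
		for i in range(R):
			if l == ((t % layers_n) + i * (layers_n // R)) % layers_n:
				return 1
		return 0

	def early_exit_factor( l ):
		# toy stand-in: weight later layers more heavily
		return l + 1

	def normalized_per_layer_loss_scale( l, t=None ):
		# normalize over the active layers so the scales sum to 1
		return (cirriculum(l, t) * early_exit_factor(l)) / sum(
			cirriculum(i, t) * early_exit_factor(i) for i in range(layers_n)
		)

	for t in range(3):
		active = [ l for l in range(layers_n) if cirriculum(l, t) ]
		scales = [ round(normalized_per_layer_loss_scale(l, t), 3) for l in active ]
		print(f"step {t}: active layers {active}, scales {scales}")
	# step 0: active layers [0, 6], scales [0.125, 0.875]
	# step 1: active layers [1, 7], scales [0.2, 0.8]
	# step 2: active layers [2, 8], scales [0.25, 0.75]

This also makes the "YUCK" guard concrete: if the curriculum ever disabled every layer at some step, the normalizing sum would be zero, hence the guarantee that at least R layers stay active.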
@@ -472,6 +472,7 @@ class Base(nn.Module):
 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
 		self.layerskip = layerskip
+		self.training_step = 0

 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -1506,12 +1507,13 @@ class Base(nn.Module):
 					self.stats[k].append( v )

 			for k, v in self.loss.items():
-				self.loss[k] = self.model.early_exit_loss( losses=v )
+				self.loss[k] = self.model.early_exit_loss( losses=v, t=self.training_step )
+				# ick
+				self.training_step += 1

 			for k, v in self.stats.items():
 				self.stats[k] = sum( v ) / len( v )


 		# include any additional losses (for example: MoE router)
 		if output.aux_loss is not None:
 			loss["aux_loss"] = output.aux_loss
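As a hedged illustration of the new wiring (toy values, not repo code): each entry of self.loss now holds a list of per-layer losses, and passing the running training_step rotates which layers' losses contribute at that step. Note the increment sits inside the loop over loss keys, so the "step" advances once per combined loss rather than once per batch, which is presumably the "# ick".

	# toy illustration of early_exit_loss( losses=v, t=self.training_step )
	layers_n, R = 12, 2

	def cirriculum( l, t ):
		return 1 if any( l == ((t % layers_n) + i * (layers_n // R)) % layers_n for i in range(R) ) else 0

	def scale( l, t ):
		weight = lambda j: cirriculum(j, t) * (j + 1)   # toy increasing factor
		return weight(l) / sum( weight(j) for j in range(layers_n) )

	losses = [ 1.0 ] * layers_n   # stand-in per-layer losses for one loss key
	training_step = 0
	loss = sum( scale(l, training_step) * losses[l] for l in range(layers_n) )
	print( loss )                 # 1.0: the scales are normalized, so a uniform loss is preserved
	training_step += 1            # advanced after each combined loss, per the diff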