diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 81fb5bd..3da812d 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -322,15 +322,23 @@ class LlamaModel_Adapted(LlamaModel):
 		return random.random() < P
 
 	# to-do: properly implement this per the paper
-	# this probably is a function of layer number and training step to decide what layer to apply layerskip to for training
-	def cirriculum( self, l, t=0 ):
-		return 1 # self.layers_n - 1
+	# the performance hit doesn't seem /too/ bad, but the paper mentions last-layer accuracy suffering if every layer always has an early exit
+	def cirriculum( self, l, t=None, R=2 ):
+		if t is None:
+			return 1
+
+		# YUCK
+		# this guarantees at least R layers are active at every step, which matters because the normalization below would otherwise divide by zero
+		for i in range(R):
+			if l == ((t % self.layers_n) + i * (self.layers_n // R)) % self.layers_n:
+				return 1
+		return 0
 
-	def early_exit_loss( self, losses, t=0 ):
+	def early_exit_loss( self, losses, t=None ):
 		return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ])
 
-	def normalized_per_layer_loss_scale( self, l, t=0 ):
-		return (self.cirriculum(l, t) * self.early_exit_factor( l )) / (sum([ self.cirriculum(l, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ]))
+	def normalized_per_layer_loss_scale( self, l, t=None ):
+		return (self.cirriculum(l, t) * self.early_exit_factor( l )) / sum([ self.cirriculum(i, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ])
 
 	def early_exit_factor( self, l ):
 		if 0 <= l and l < self.layers_n:
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index d2ea128..09cea8c 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -472,6 +472,7 @@ class Base(nn.Module):
 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
 		self.layerskip = layerskip
+		self.training_step = 0
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -1506,12 +1507,13 @@ class Base(nn.Module):
 					self.stats[k].append( v )
 
 			for k, v in self.loss.items():
-				self.loss[k] = self.model.early_exit_loss( losses=v )
+				self.loss[k] = self.model.early_exit_loss( losses=v, t=self.training_step )
+				# ick
+				self.training_step += 1
 
 			for k, v in self.stats.items():
 				self.stats[k] = sum( v ) / len( v )
 
-		# include any additional losses (for example: MoE router)
 		if output.aux_loss is not None:
 			loss["aux_loss"] = output.aux_loss
 
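
For reference, a minimal standalone sketch (not the repo's code) of what the rotational cirriculum and normalized per-layer loss scale above compute; layers_n, R, and the toy early_exit_factor below are illustrative assumptions, not the repo's values:

    # standalone illustration only; spellings match the repo's identifiers
    layers_n = 12  # assumed layer count for demonstration
    R = 2          # number of simultaneously active early-exit layers

    def cirriculum(l, t=None, R=R):
        # no training step given -> every layer participates
        if t is None:
            return 1
        # R layers, spaced layers_n // R apart, rotate with the training step t
        for i in range(R):
            if l == ((t % layers_n) + i * (layers_n // R)) % layers_n:
                return 1
        return 0

    def early_exit_factor(l):
        # toy stand-in: weight deeper layers more heavily (the repo's actual factor differs)
        return l + 1

    def normalized_per_layer_loss_scale(l, t=None):
        # scales of the currently active layers are normalized to sum to 1
        total = sum(cirriculum(i, t) * early_exit_factor(i) for i in range(layers_n))
        return (cirriculum(l, t) * early_exit_factor(l)) / total

    for t in range(3):
        active = [l for l in range(layers_n) if cirriculum(l, t)]
        total = sum(normalized_per_layer_loss_scale(l, t) for l in range(layers_n))
        print(t, active, round(total, 6))
    # t=0 -> layers [0, 6], t=1 -> [1, 7], t=2 -> [2, 8]; the scales always sum to 1.0

Note that the denominator fix in the diff (summing cirriculum(i, t) over i rather than reusing cirriculum(l, t)) is what keeps the active layers' scales summing to 1 at every step.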