diff --git a/vall_e/vall_e/base.py b/vall_e/vall_e/base.py index da706b2..bd9fe36 100644 --- a/vall_e/vall_e/base.py +++ b/vall_e/vall_e/base.py @@ -231,6 +231,10 @@ class Base(nn.Module): def use_stop_token(self) -> bool: raise NotImplementedError + @property + def n_prom_levels(self) -> int: + return 8 + def __init__( self, n_tokens: int, @@ -238,7 +242,6 @@ class Base(nn.Module): n_heads: int = 8, n_layers: int = 12, p_dropout: float = 0.1, - n_prom_levels: int = 8, resp_loss_only: bool = False, ): super().__init__() @@ -254,7 +257,7 @@ class Base(nn.Module): # It's not clear whether the whole prom are used or only the first level quantization # Just use all of them as it is more sufficient and we don't need to sample it, or do we? - self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels) + self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=self.n_prom_levels) # +1 to include the stop token # Note that, for different levels, I don't use AdaLN for simplicity