Hardcode prom levels
This commit is contained in:
parent
49250b3c17
commit
ae029c1d75
|
@ -231,6 +231,10 @@ class Base(nn.Module):
|
|||
def use_stop_token(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def n_prom_levels(self) -> int:
|
||||
return 8
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_tokens: int,
|
||||
|
@ -238,7 +242,6 @@ class Base(nn.Module):
|
|||
n_heads: int = 8,
|
||||
n_layers: int = 12,
|
||||
p_dropout: float = 0.1,
|
||||
n_prom_levels: int = 8,
|
||||
resp_loss_only: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -254,7 +257,7 @@ class Base(nn.Module):
|
|||
|
||||
# It's not clear whether the whole prom are used or only the first level quantization
|
||||
# Just use all of them as it is more sufficient and we don't need to sample it, or do we?
|
||||
self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)
|
||||
self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=self.n_prom_levels)
|
||||
|
||||
# +1 to include the stop token
|
||||
# Note that, for different levels, I don't use AdaLN for simplicity
|
||||
|
|
Loading…
Reference in New Issue
Block a user