Hardcode prom levels
This commit is contained in:
parent
49250b3c17
commit
ae029c1d75
|
@ -231,6 +231,10 @@ class Base(nn.Module):
|
||||||
def use_stop_token(self) -> bool:
|
def use_stop_token(self) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_prom_levels(self) -> int:
|
||||||
|
return 8
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n_tokens: int,
|
n_tokens: int,
|
||||||
|
@ -238,7 +242,6 @@ class Base(nn.Module):
|
||||||
n_heads: int = 8,
|
n_heads: int = 8,
|
||||||
n_layers: int = 12,
|
n_layers: int = 12,
|
||||||
p_dropout: float = 0.1,
|
p_dropout: float = 0.1,
|
||||||
n_prom_levels: int = 8,
|
|
||||||
resp_loss_only: bool = False,
|
resp_loss_only: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -254,7 +257,7 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# It's not clear whether the whole prom is used or only the first quantization level
|
# It's not clear whether the whole prom is used or only the first quantization level
|
||||||
# Just use all of them as it is more sufficient and we don't need to sample it, or do we?
|
# Just use all of them as it is more sufficient and we don't need to sample it, or do we?
|
||||||
self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)
|
self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=self.n_prom_levels)
|
||||||
|
|
||||||
# +1 to include the stop token
|
# +1 to include the stop token
|
||||||
# Note that, for different levels, I don't use AdaLN for simplicity
|
# Note that, for different levels, I don't use AdaLN for simplicity
|
||||||
|
|
Loading…
Reference in New Issue
Block a user