Hardcode prom levels

enhuiz 2023-01-12 19:44:59 +08:00
parent 49250b3c17
commit ae029c1d75


@@ -231,6 +231,10 @@ class Base(nn.Module):
     def use_stop_token(self) -> bool:
         raise NotImplementedError
 
+    @property
+    def n_prom_levels(self) -> int:
+        return 8
+
     def __init__(
         self,
         n_tokens: int,
@@ -238,7 +242,6 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
-        n_prom_levels: int = 8,
         resp_loss_only: bool = False,
     ):
         super().__init__()
@@ -254,7 +257,7 @@ class Base(nn.Module):
         # It's not clear whether the whole prom are used or only the first level quantization
         # Just use all of them as it is more sufficient and we don't need to sample it, or do we?
-        self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)
+        self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=self.n_prom_levels)
         # +1 to include the stop token
         # Note that, for different levels, I don't use AdaLN for simplicity
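For context, a minimal sketch of the pattern this commit adopts: the number of prompt quantization levels becomes a read-only class property rather than a constructor argument, so a subclass changes it by overriding the property instead of threading it through __init__. MultiEmbedding below is a simplified stand-in for the repository's real module (one embedding table per quantization level, summed), and the d_model default and FourLevelModel subclass are illustrative; only Base.n_prom_levels and its use in __init__ mirror the diff.

import torch
from torch import nn

class MultiEmbedding(nn.Module):
    # Simplified stand-in: one embedding slice per quantization level, summed.
    def __init__(self, n_tokens: int, d_model: int, n_levels: int):
        super().__init__()
        self.n_tokens = n_tokens
        self.emb = nn.Embedding(n_levels * n_tokens, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., n_levels) integer codes; shift each level into its own
        # slice of the shared table, embed, then sum across levels.
        offsets = torch.arange(x.shape[-1], device=x.device) * self.n_tokens
        return self.emb(x + offsets).sum(dim=-2)

class Base(nn.Module):
    @property
    def n_prom_levels(self) -> int:
        return 8  # hardcoded default, as in the commit

    def __init__(self, n_tokens: int, d_model: int = 512):
        super().__init__()
        # The property is resolved on self, so a subclass override takes effect here.
        self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=self.n_prom_levels)

class FourLevelModel(Base):
    @property
    def n_prom_levels(self) -> int:
        return 4  # hypothetical subclass: picks a different level count

model = FourLevelModel(n_tokens=1024)
prom = torch.randint(0, 1024, (2, 50, model.n_prom_levels))  # (batch, time, levels)
print(model.prom_emb(prom).shape)  # torch.Size([2, 50, 512])

The trade-off this sketch illustrates: hardcoding the value as a property shortens the constructor signature and removes a knob callers rarely touched, while still leaving subclasses a clean override point.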