fixed the AR+NAR dual model: the resp_emb has to be split up (the classifier might need to be, too)

mrq 2023-09-06 19:33:39 -05:00
parent 100ca6b7d0
commit 7ce06432fd
4 changed files with 39 additions and 27 deletions
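The gist of the fix, sketched below outside the repo's code (hypothetical names; a plain nn.Embedding stands in for the repo's MultiEmbedding): the dual AR+NAR model keeps two separate response embeddings and indexes them per forward pass, slot 0 when quant_levels is None (the AR pass) and slot 1 otherwise (the NAR pass).

import torch
import torch.nn as nn

class DualRespEmbSketch(nn.Module):
    """Two response embeddings, selected by whether quant_levels is given."""
    def __init__(self, n_resp_tokens: int = 1025, d_model: int = 512):
        super().__init__()
        # slot 0 serves the AR pass, slot 1 serves the NAR pass
        self.resps_emb = nn.ModuleList([nn.Embedding(n_resp_tokens, d_model) for _ in range(2)])

    def forward(self, resps: torch.Tensor, quant_levels=None) -> torch.Tensor:
        # the AR pass omits quant_levels; the NAR pass supplies them
        return self.resps_emb[0 if quant_levels is None else 1](resps)

emb = DualRespEmbSketch()
ar_x  = emb(torch.randint(0, 1025, (10,)))                    # AR pass  -> resps_emb[0]
nar_x = emb(torch.randint(0, 1025, (10,)), quant_levels=[1])  # NAR pass -> resps_emb[1]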

View File

@@ -53,6 +53,10 @@ class AR(Base):
            return self.config.interleave
        return False
    @property
    def dual(self) -> bool:
        return False
    def _prune(self, l: Tensor):
        indices = (l == self.stop_token).nonzero()
        if len(indices) == 0:

View File

@@ -16,7 +16,7 @@ class AR_NAR(Base):
    @property
    def norm_type(self):
        return "ln"
        return "ln" if self.n_resp_levels == 1 else "adaln"
    @property
    def arch_type(self) -> str:
@@ -44,15 +44,15 @@ class AR_NAR(Base):
    @property
    def recurrent_chunk_size(self) -> int:
        if cfg.mode == "training":
            return 0
        return cfg.inference.recurrent_chunk_size
        return 0
    @property
    def interleave(self) -> bool:
        if hasattr(self, "config") and self.config:
            return self.config.interleave
        return False
    @property
    def dual(self) -> bool:
        return True
    def _prune(self, l: Tensor):
        indices = (l == self.stop_token).nonzero()
@@ -60,18 +60,6 @@ class AR_NAR(Base):
            return l
        return l[: indices.min().item()]
    def _interleave( self, codes ):
        if not self.interleave:
            return codes
        return codes.flatten()
    def _deinterleave( self, codes, length = 0 ):
        if not self.interleave:
            return codes
        return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
    @staticmethod
    def _unsqueeze_list(x_list, axis=-1):
        return [x.unsqueeze(dim=axis) for x in x_list]
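Per the -18/+6 hunk header above, the _interleave/_deinterleave helpers are dropped from AR_NAR here. For context, a rough standalone sketch of what they did (generic n_levels in place of the original's self.n_prom_levels): interleaving flattens a [T, n_levels] code matrix into a 1-D stream, and deinterleaving truncates any trailing partial frame and reshapes back.

import torch

def interleave(codes: torch.Tensor) -> torch.Tensor:
    # [T, n_levels] -> [T * n_levels], frame by frame
    return codes.flatten()

def deinterleave(codes: torch.Tensor, n_levels: int) -> torch.Tensor:
    # drop a trailing partial frame, then restore [T, n_levels]
    frames = codes.shape[0] // n_levels
    return codes[: frames * n_levels].unflatten(0, (frames, n_levels))

codes = torch.randint(0, 1024, (75, 8))  # 75 frames, 8 codebook levels
assert torch.equal(deinterleave(interleave(codes), 8), codes)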
@@ -243,7 +231,7 @@ def example_usage():
    def train():
        engine.train()
        t = trange(5000)
        t = trange(1000)
        for i in t:
            stats = {"step": i}
            stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)

View File

@@ -126,6 +126,10 @@ class Base(nn.Module):
    def interleave(self) -> bool:
        return False
    @property
    def dual(self) -> bool:
        return False
    @property
    def stop_token(self):
        if not self.causal:
@@ -151,7 +155,7 @@ class Base(nn.Module):
        n_heads: int = 8,
        n_layers: int = 12,
        p_dropout: float = 0.1,
        config = None,
    ):
        super().__init__()
@@ -169,7 +173,11 @@ class Base(nn.Module):
        self.text_emb = Embedding(n_tokens, d_model)
        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
        self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
        if self.dual:
            self.resps_emb = nn.ModuleList([MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(2)])
        else:
            self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
        self.sep = nn.Parameter(torch.randn(d_model))
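MultiEmbedding itself is defined elsewhere in base.py and is not part of this diff; roughly (a hedged sketch, not the repo's implementation), it gives each residual level its own embedding table and sums the per-level embeddings, which is why one instance covers all n_resp_levels and why the dual model simply builds two of them.

import torch
import torch.nn as nn

class MultiEmbeddingSketch(nn.Module):
    """One table per codebook level; per-level embeddings are summed."""
    def __init__(self, n_levels: int, n_tokens: int, d_model: int):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: [T, n_levels] of token ids -> [T, d_model]
        return sum(self.embs[l](codes[:, l]) for l in range(codes.shape[1]))

emb = MultiEmbeddingSketch(n_levels=8, n_tokens=1025, d_model=512)
x = emb(torch.randint(0, 1025, (75, 8)))  # -> [75, 512]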
@@ -254,12 +262,20 @@ class Base(nn.Module):
        state: dict | None = None,
    ):
        x_list = self._samplewise_merge_tensors(
            self.text_emb(text_list),
            self.proms_emb(proms_list),
            self.resps_emb(resps_list),
            sep=self.sep,
        )
        if self.dual:
            x_list = self._samplewise_merge_tensors(
                self.text_emb(text_list),
                self.proms_emb(proms_list),
                self.resps_emb[0 if quant_levels is None else 1](resps_list),
                sep=self.sep,
            )
        else:
            x_list = self._samplewise_merge_tensors(
                self.text_emb(text_list),
                self.proms_emb(proms_list),
                self.resps_emb(resps_list),
                sep=self.sep,
            )
        x, m = list_to_tensor(x_list)
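The two branches above differ only in how resps_list is embedded; a hypothetical helper (not in the repo) that captures the selection mirrors the hunk: the dual layout indexes the two-element ModuleList with 0 for the AR pass (quant_levels is None) and 1 for the NAR pass, while the single-model layout keeps calling the shared embedding directly.

# Hypothetical helper mirroring the branch above; `resps_emb` is either the
# two-element nn.ModuleList (dual model) or a single shared embedding module.
def embed_resps(resps_emb, resps_list, dual: bool, quant_levels=None):
    if dual:
        # AR pass (no quant_levels) -> index 0, NAR pass -> index 1
        return resps_emb[0 if quant_levels is None else 1](resps_list)
    # non-dual models keep a single shared embedding
    return resps_emb(resps_list)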

View File

@@ -47,6 +47,10 @@ class NAR(Base):
    def interleave(self) -> bool:
        return False
    @property
    def dual(self) -> bool:
        return False
    def forward(
        self,
        text_list: list[Tensor],