fixed the AR+NAR dual model: the resp_emb has to be split up (the classifier might, too)

parent 100ca6b7d0
commit 7ce06432fd
```diff
@@ -53,6 +53,10 @@ class AR(Base):
 			return self.config.interleave
 		return False
 
+	@property
+	def dual(self) -> bool:
+		return False
+
 	def _prune(self, l: Tensor):
 		indices = (l == self.stop_token).nonzero()
 		if len(indices) == 0:
```
```diff
@@ -16,7 +16,7 @@ class AR_NAR(Base):
 
 	@property
 	def norm_type(self):
-		return "ln"
+		return "ln" if self.n_resp_levels == 1 else "adaln"
 
 	@property
 	def arch_type(self) -> str:
```
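The `norm_type` switch matters because a shared AR+NAR model has to be told which quantizer level it is currently working on; adaptive LayerNorm injects that conditioning into every layer. Below is a minimal sketch of the idea only, not the repo's actual AdaLN implementation (the class name, shapes, and the `(1 + gamma)` parameterization are assumptions):

```python
import torch
from torch import nn

class AdaLN(nn.Module):
	"""Sketch of adaptive LayerNorm: scale/shift conditioned on the RVQ level."""
	def __init__(self, d_model: int, n_levels: int):
		super().__init__()
		self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
		# one learned (gamma, beta) pair per quantizer level
		self.emb = nn.Embedding(n_levels, d_model * 2)

	def forward(self, x: torch.Tensor, level: torch.Tensor) -> torch.Tensor:
		# x: [B, T, D], level: [B] long tensor of quantizer indices
		gamma, beta = self.emb(level).unsqueeze(1).chunk(2, dim=-1)
		return (1 + gamma) * self.norm(x) + beta
```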
```diff
@@ -44,15 +44,15 @@ class AR_NAR(Base):
 
 	@property
 	def recurrent_chunk_size(self) -> int:
-		if cfg.mode == "training":
-			return 0
-		return cfg.inference.recurrent_chunk_size
+		return 0
 
 	@property
 	def interleave(self) -> bool:
-		if hasattr(self, "config") and self.config:
-			return self.config.interleave
 		return False
+
+	@property
+	def dual(self) -> bool:
+		return True
 
 	def _prune(self, l: Tensor):
 		indices = (l == self.stop_token).nonzero()
```
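`dual` returning `True` only for `AR_NAR` is what the rest of this commit keys off of: one set of weights serves both the causal AR pass (quantizer level 0) and the parallel NAR pass (levels 1 and up), and the two passes are told apart by whether `quant_levels` is supplied. A hedged sketch of that convention (the helper name and the 50/50 split are illustrative, not the repo's training loop):

```python
import random
import torch

def pick_quant_levels(batch_size: int, n_resp_levels: int) -> torch.Tensor | None:
	# None -> AR step: predict level-0 codes causally (uses resps_emb[0])
	# tensor of levels >= 1 -> NAR step: predict one level in parallel (uses resps_emb[1])
	if random.random() < 0.5:
		return None
	return torch.randint(1, n_resp_levels, (batch_size,))
```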
```diff
@@ -60,18 +60,6 @@ class AR_NAR(Base):
 			return l
 		return l[: indices.min().item()]
 
-	def _interleave( self, codes ):
-		if not self.interleave:
-			return codes
-
-		return codes.flatten()
-
-	def _deinterleave( self, codes, length = 0 ):
-		if not self.interleave:
-			return codes
-
-		return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
-
 	@staticmethod
 	def _unsqueeze_list(x_list, axis=-1):
 		return [x.unsqueeze(dim=axis) for x in x_list]
```
```diff
@@ -243,7 +231,7 @@ def example_usage():
 
 	def train():
 		engine.train()
-		t = trange(5000)
+		t = trange(1000)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
```
```diff
@@ -126,6 +126,10 @@ class Base(nn.Module):
 	def interleave(self) -> bool:
 		return False
 
+	@property
+	def dual(self) -> bool:
+		return False
+
 	@property
 	def stop_token(self):
 		if not self.causal:
```
```diff
@@ -151,7 +155,7 @@ class Base(nn.Module):
 		n_heads: int = 8,
 		n_layers: int = 12,
 		p_dropout: float = 0.1,
-
+
 		config = None,
 	):
 		super().__init__()
```
```diff
@@ -169,7 +173,11 @@ class Base(nn.Module):
 
 		self.text_emb = Embedding(n_tokens, d_model)
 		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-		self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+
+		if self.dual:
+			self.resps_emb = nn.ModuleList([MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(2)])
+		else:
+			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
 
 		self.sep = nn.Parameter(torch.randn(d_model))
```
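The point of the split: the AR pass embeds level-0 codes (whose vocabulary also carries the stop token), while the NAR pass embeds levels 1 and up, and letting both share one table mixes two different distributions. A rough sketch of the resulting layout, with a hypothetical stand-in for the repo's `MultiEmbedding` (the real one differs; this stub only shows the sum-over-levels shape):

```python
from torch import nn

class MultiEmbeddingStub(nn.Module):
	"""Hypothetical stand-in: sums one embedding per RVQ level."""
	def __init__(self, n_levels: int, n_tokens: int, d_model: int):
		super().__init__()
		self.embs = nn.ModuleList(nn.Embedding(n_tokens, d_model) for _ in range(n_levels))

	def forward(self, x):
		# x: [B, T, n_levels] integer codes
		return sum(emb(x[..., l]) for l, emb in enumerate(self.embs))

# dual model: two independent tables, one per decoding mode
resps_emb = nn.ModuleList([MultiEmbeddingStub(8, 1025, 512) for _ in range(2)])
ar_emb  = resps_emb[0]  # AR pass: level-0 history, stop token included
nar_emb = resps_emb[1]  # NAR pass: levels 1..n_resp_levels-1
```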
```diff
@@ -254,12 +262,20 @@ class Base(nn.Module):
 
 		state: dict | None = None,
 	):
-		x_list = self._samplewise_merge_tensors(
-			self.text_emb(text_list),
-			self.proms_emb(proms_list),
-			self.resps_emb(resps_list),
-			sep=self.sep,
-		)
+		if self.dual:
+			x_list = self._samplewise_merge_tensors(
+				self.text_emb(text_list),
+				self.proms_emb(proms_list),
+				self.resps_emb[0 if quant_levels is None else 1](resps_list),
+				sep=self.sep,
+			)
+		else:
+			x_list = self._samplewise_merge_tensors(
+				self.text_emb(text_list),
+				self.proms_emb(proms_list),
+				self.resps_emb(resps_list),
+				sep=self.sep,
+			)
 
 		x, m = list_to_tensor(x_list)
 
```
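To make the selection concrete, here is the `0 if quant_levels is None else 1` dispatch on toy tensors (flat `nn.Embedding` tables stand in for `MultiEmbedding`, and the shapes are simplified from the real `forward`):

```python
import torch
from torch import nn

d_model, n_tokens = 512, 1025
resps_emb = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(2)])
resps = torch.randint(0, n_tokens, (1, 50))  # toy level-0 response codes

# AR pass: no quant_levels, so table 0 embeds the causal history
quant_levels = None
x_ar = resps_emb[0 if quant_levels is None else 1](resps)

# NAR pass: one target level per sample, so table 1 embeds the known levels
quant_levels = torch.randint(1, 8, (1,))
x_nar = resps_emb[0 if quant_levels is None else 1](resps)
```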
```diff
@@ -47,6 +47,10 @@ class NAR(Base):
 	def interleave(self) -> bool:
 		return False
 
+	@property
+	def dual(self) -> bool:
+		return False
+
 	def forward(
 		self,
 		text_list: list[Tensor],
```