actually use langs from the dataloader

This commit is contained in:
mrq 2023-10-11 21:21:50 -05:00
parent 3af19d79fd
commit 08bae355eb
5 changed files with 11 additions and 3 deletions

View File

@ -102,6 +102,7 @@ class AR(Base):
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
max_steps: int = 1000,
max_resp_context: int = -1,
@ -128,6 +129,7 @@ class AR(Base):
proms_list=proms_list,
resps_list=self._unsqueeze_list(resps_list),
targ_list=resps_list,
lang_list=lang_list,
quant_levels=None,
)

View File

@ -134,6 +134,7 @@ class AR_NAR(Base):
proms_list=proms_list,
resps_list=resps_list,
targ_list=targ_list,
lang_list=lang_list,
quant_levels=quant_levels,
)
# is NAR
@ -153,6 +154,7 @@ class AR_NAR(Base):
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
quant_levels=quant_levels,
)
@ -204,7 +206,7 @@ class AR_NAR(Base):
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
state=recurrent_state
)

View File

@ -304,7 +304,7 @@ class Base(nn.Module):
batch_size = len(text_list)
if self.langs_emb is None:
langs_list = None
lang_list = None
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),

View File

@ -74,6 +74,7 @@ class NAR(Base):
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
lang_list: list[Tensor] | None = None,
max_levels: int = 0,
sampling_temperature: float = 0.2,
sampling_min_temperature: float = -1.0,
@ -118,6 +119,7 @@ class NAR(Base):
proms_list=proms_list,
resps_list=prev_list,
targ_list=targ_list,
lang_list=lang_list,
quant_levels=quant_levels,
)
@ -139,6 +141,7 @@ class NAR(Base):
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
quant_levels=quant_levels,
)

View File

@ -26,7 +26,8 @@ def train_feeder(engine, batch):
engine(
text_list=batch["text"],
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
resps_list=batch["resps"]
resps_list=batch["resps"],
lang_list=batch["lang"],
)
losses = engine.gather_attribute("loss")