actually use langs from the dataloader
parent 3af19d79fd
commit 08bae355eb
@@ -102,6 +102,7 @@ class AR(Base):
         text_list: list[Tensor],
         proms_list: list[Tensor],
         resps_list: list[Tensor] | None = None,
+        lang_list: list[Tensor] | None = None,
         max_steps: int = 1000,
         max_resp_context: int = -1,
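With `lang_list` accepted by the AR forward pass, callers can hand a per-sample language id through to the base model. A minimal usage sketch follows; the already-constructed `model` instance, the dummy tensor shapes, and the 0 == "en" index are illustrative assumptions, not part of this commit.

import torch

# Inference sketch with the new keyword. `model` is assumed to be an
# already-constructed AR instance; shapes and the 0 == "en" index are
# assumptions for illustration only.
text = torch.randint(0, 256, (32,))        # tokenized text for one sample
prom = torch.randint(0, 1024, (150, 8))    # acoustic prompt codes (T, levels)
lang = torch.tensor([0])                   # e.g. 0 == "en" in the language map

resps_list = model(
    text_list=[text],
    proms_list=[prom],
    lang_list=[lang],   # new in this commit
    max_steps=1000,
)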
@@ -128,6 +129,7 @@ class AR(Base):
             proms_list=proms_list,
             resps_list=self._unsqueeze_list(resps_list),
             targ_list=resps_list,
+            lang_list=lang_list,
             quant_levels=None,
         )
@@ -134,6 +134,7 @@ class AR_NAR(Base):
             proms_list=proms_list,
             resps_list=resps_list,
             targ_list=targ_list,
+            lang_list=lang_list,
             quant_levels=quant_levels,
         )
         # is NAR
@@ -153,6 +154,7 @@ class AR_NAR(Base):
             text_list=text_list,
             proms_list=proms_list,
             resps_list=prev_list,
+            lang_list=lang_list,
             quant_levels=quant_levels,
         )
@@ -204,7 +206,7 @@ class AR_NAR(Base):
             text_list=text_list,
             proms_list=proms_list,
             resps_list=resps_list,
-
+            lang_list=lang_list,
             state=recurrent_state
         )
@@ -304,7 +304,7 @@ class Base(nn.Module):
         batch_size = len(text_list)

         if self.langs_emb is None:
-            langs_list = None
+            lang_list = None

         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
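The hunk above only renames the fallback variable so it matches the `lang_list` argument the subclasses now forward; how the language embedding is produced is not shown in this diff. A rough sketch consistent with the visible `langs_emb` guard is below — the module construction, the shapes, and the idea that the per-sample vectors get merged alongside the text embeddings are assumptions drawn from the surrounding context, not copied from the repository.

import torch
from torch import nn, Tensor

# Rough sketch: embed a per-sample language index so it can be merged with the
# text embeddings, mirroring the `langs_emb` / `lang_list` names in the diff.
# Construction and shapes here are assumptions, not the repository's code.
class LangConditioning(nn.Module):
    def __init__(self, n_langs: int | None, d_model: int):
        super().__init__()
        # models trained without language ids simply have no embedding table
        self.langs_emb = nn.Embedding(n_langs, d_model) if n_langs else None

    def embed(self, lang_list: list[Tensor] | None) -> list[Tensor] | None:
        # matches the guard above: no table (or no ids) means no conditioning
        if self.langs_emb is None or lang_list is None:
            return None
        # one embedding per sample, e.g. lang = torch.tensor([0]) for "en"
        return [self.langs_emb(lang) for lang in lang_list]

cond = LangConditioning(n_langs=2, d_model=1024)
lang_embs = cond.embed([torch.tensor([0])])   # -> list of (1, d_model) tensors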
@@ -74,6 +74,7 @@ class NAR(Base):
         text_list: list[Tensor],
         proms_list: list[Tensor],
         resps_list: list[Tensor],
+        lang_list: list[Tensor] | None = None,
         max_levels: int = 0,
         sampling_temperature: float = 0.2,
         sampling_min_temperature: float = -1.0,
@@ -118,6 +119,7 @@ class NAR(Base):
             proms_list=proms_list,
             resps_list=prev_list,
             targ_list=targ_list,
+            lang_list=lang_list,
             quant_levels=quant_levels,
         )
@@ -139,6 +141,7 @@ class NAR(Base):
             text_list=text_list,
             proms_list=proms_list,
             resps_list=prev_list,
+            lang_list=lang_list,
             quant_levels=quant_levels,
         )
@@ -26,7 +26,8 @@ def train_feeder(engine, batch):
     engine(
         text_list=batch["text"],
         proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
-        resps_list=batch["resps"]
+        resps_list=batch["resps"],
+        lang_list=batch["lang"],
     )

     losses = engine.gather_attribute("loss")
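The trainer-side change assumes each batch from the dataloader now carries a "lang" entry next to "text", "proms", and "resps". A sketch of what such a batch could look like follows; the language-to-index mapping and the dummy token ranges/shapes are illustrative assumptions, not taken from the dataloader itself.

import torch

# Hypothetical shape of one training batch consumed by train_feeder above.
# The language map and the dummy shapes are assumptions for illustration.
LANGS = {"en": 0, "ja": 1}

batch = {
    "text":  [torch.randint(0, 256, (32,))],       # tokenized text, one sample
    "proms": [torch.randint(0, 1024, (150, 8))],    # acoustic prompt codes (T, levels)
    "resps": [torch.randint(0, 1024, (300, 8))],    # target response codes (T, levels)
    "lang":  [torch.tensor([LANGS["en"]])],         # new: per-sample language index
}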