Use different sampling temperature for AR and NAR
This commit is contained in:
parent
77b52e42ce
commit
8188506440
@ -5,3 +5,5 @@ model: ar
|
||||
batch_size: 24
|
||||
eval_batch_size: 24
|
||||
eval_every: 10_000
|
||||
|
||||
sampling_temperature: 1.0
|
||||
|
@ -5,3 +5,5 @@ model: nar
|
||||
batch_size: 24
|
||||
eval_batch_size: 24
|
||||
eval_every: 1_000
|
||||
|
||||
sampling_temperature: 0.2
|
||||
|
@ -46,6 +46,7 @@ class Config(ConfigBase):
|
||||
|
||||
use_fp16: bool = True
|
||||
gradient_accumulation_steps: int = 1
|
||||
sampling_temperature: float = 1.0
|
||||
|
||||
@cached_property
|
||||
def get_spkr(self):
|
||||
|
@ -76,6 +76,7 @@ def main():
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
max_steps=cfg.max_val_ar_steps,
|
||||
sampling_temperature=cfg.sampling_temperature,
|
||||
)
|
||||
resps_list = [r.unsqueeze(-1) for r in resp_list]
|
||||
elif cfg.model.startswith("nar"):
|
||||
@ -83,6 +84,7 @@ def main():
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
resp_list=batch["resp"],
|
||||
sampling_temperature=cfg.sampling_temperature,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(cfg.model)
|
||||
|
@ -39,6 +39,7 @@ class AR(Base):
|
||||
proms_list: list[Tensor],
|
||||
resp_list: list[Tensor] | None = None,
|
||||
max_steps: int = 1000,
|
||||
sampling_temperature: float = 1.0,
|
||||
):
|
||||
if resp_list is not None:
|
||||
return super().forward(
|
||||
@ -51,13 +52,19 @@ class AR(Base):
|
||||
return_all_resp=False,
|
||||
)
|
||||
else:
|
||||
return self._generate(text_list, proms_list, max_steps)
|
||||
return self._generate(
|
||||
text_list,
|
||||
proms_list,
|
||||
max_steps,
|
||||
sampling_temperature,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
text_list: list[Tensor],
|
||||
proms_list: list[Tensor],
|
||||
max_steps: int,
|
||||
sampling_temperature: float,
|
||||
):
|
||||
device = text_list[0].device
|
||||
resp_list: list[Tensor] = [
|
||||
@ -65,7 +72,12 @@ class AR(Base):
|
||||
]
|
||||
stopped = torch.zeros(len(text_list), device=device).bool()
|
||||
for _ in trange(max_steps):
|
||||
r = super().forward(text_list, proms_list, resp_list)
|
||||
r = super().forward(
|
||||
text_list,
|
||||
proms_list,
|
||||
resp_list,
|
||||
sampling_temperature=sampling_temperature,
|
||||
)
|
||||
stopped |= r == self.stop_token
|
||||
for i, ri in enumerate(r):
|
||||
resp_list[i] = torch.cat([resp_list[i], ri[None]])
|
||||
|
@ -407,6 +407,7 @@ class Base(nn.Module):
|
||||
quant_levels: Tensor | None = None,
|
||||
shift_targ_list: bool = False,
|
||||
return_all_resp: Literal[False] = False,
|
||||
sampling_temperature: float = 1.0,
|
||||
) -> Tensor:
|
||||
...
|
||||
|
||||
@ -420,6 +421,7 @@ class Base(nn.Module):
|
||||
quant_levels: Tensor | None = None,
|
||||
shift_targ_list: bool = False,
|
||||
return_all_resp: Literal[True] = True,
|
||||
sampling_temperature: float = 1.0,
|
||||
) -> list[Tensor]:
|
||||
...
|
||||
|
||||
@ -432,7 +434,7 @@ class Base(nn.Module):
|
||||
quant_levels: Tensor | None = None,
|
||||
shift_targ_list: bool = False,
|
||||
return_all_resp: bool = False,
|
||||
sampling_temperature: float = 0.2,
|
||||
sampling_temperature: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -33,6 +33,7 @@ class NAR(Base):
|
||||
*,
|
||||
resp_list: list[Tensor] | None = None,
|
||||
resps_list: list[Tensor] | None = None,
|
||||
sampling_temperature: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -90,6 +91,7 @@ class NAR(Base):
|
||||
return_all_resp=True,
|
||||
shift_targ_list=False,
|
||||
quant_levels=quant_levels,
|
||||
sampling_temperature=sampling_temperature,
|
||||
)
|
||||
hyp_resp_lists.append(hyp_resp_list)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user