This commit is contained in:
mrq 2023-10-03 15:01:37 -05:00
parent d12877ee09
commit 777ba43305

View File

@ -106,7 +106,7 @@ class AR_NAR(Base):
# is training # is training
if n_levels == self.n_resp_levels: if n_levels == self.n_resp_levels:
if cfg.models.ar_nar.p_ar_level == "auto": if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
else: else:
quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ]) quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])