diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 53a3793..94201e5 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -106,7 +106,7 @@ class AR_NAR(Base): # is training if n_levels == self.n_resp_levels: - if cfg.models.ar_nar.p_ar_level == "auto": + if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None: quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) else: quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])