From 777ba43305461305a8eb24ef3a4a05cfb544992d Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 3 Oct 2023 15:01:37 -0500 Subject: [PATCH] oops --- vall_e/models/ar_nar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 53a3793..94201e5 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -106,7 +106,7 @@ class AR_NAR(Base): # is training if n_levels == self.n_resp_levels: - if cfg.models.ar_nar.p_ar_level == "auto": + if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None: quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) else: quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])