oops
This commit is contained in:
parent
d12877ee09
commit
777ba43305
|
@ -106,7 +106,7 @@ class AR_NAR(Base):
|
|||
|
||||
# is training
|
||||
if n_levels == self.n_resp_levels:
|
||||
if cfg.models.ar_nar.p_ar_level == "auto":
|
||||
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
|
||||
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
else:
|
||||
quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
|
||||
|
|
Loading…
Reference in New Issue
Block a user