oops
This commit is contained in:
parent
d12877ee09
commit
777ba43305
|
@ -106,7 +106,7 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
# is training
|
# is training
|
||||||
if n_levels == self.n_resp_levels:
|
if n_levels == self.n_resp_levels:
|
||||||
if cfg.models.ar_nar.p_ar_level == "auto":
|
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
|
||||||
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||||
else:
|
else:
|
||||||
quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
|
quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user