diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 7e0ecc5..466537f 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -173,8 +173,11 @@ class NAR(Base): quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] # clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC... + for i, prom in enumerate(proms_list): + if quant_levels[i] + 1 > prom.shape[-1]: + quant_levels[i] = prom.shape[-1] - 1 for i, resp in enumerate(resps_list): - if quant_levels[i] >= resp.shape[-1]: + if quant_levels[i] + 1 > resp.shape[-1]: quant_levels[i] = resp.shape[-1] - 1 resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]