This commit is contained in:
mrq 2024-06-12 00:14:29 -05:00
parent cca542a4c0
commit a9353cf9fa

View File

@ -173,8 +173,11 @@ class NAR(Base):
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
for i, prom in enumerate(proms_list):
if quant_levels[i] + 1 > prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
for i, resp in enumerate(resps_list):
if quant_levels[i] >= resp.shape[-1]:
if quant_levels[i] + 1 > resp.shape[-1]:
quant_levels[i] = resp.shape[-1] - 1
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]