one more note

This commit is contained in:
mrq 2024-11-07 09:11:21 -06:00
parent 5698188824
commit d13ab00ad8

View File

@ -218,9 +218,11 @@ class NAR(Base):
# fill with mock tokens
#prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
#prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]
#prev_list = [ None for resp_len in len_list ] # this breaks the position ID calc
prev_list = [ torch.concat([ self.dropout_token.unsqueeze(0) for _ in range( resp_len ) ]) for resp_len in len_list ]
#prev_list = [ None for resp_len in len_list ]
# to-do: special "scheduling" to inference RVQ-level 0
# to-do: figure out why this fails when I copy some things from ar_nar
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):