repeat-extend the prom to fill the initial tokens for nar-len (it somewhat works, the model just needs more training)

mrq 2024-11-06 23:29:53 -06:00
parent a3bc26f7ec
commit 77ff23e319
2 changed files with 7 additions and 3 deletions

View File

@@ -356,6 +356,7 @@ class TTS():
 			)
 		elif model_len is not None:
 			len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
+			len_list = [ min(l, max_ar_steps) for l in len_list ]
 			resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
 				max_levels=max_nar_levels,
 				sampling_temperature=nar_temp,
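
The single added line above caps each predicted duration at max_ar_steps, so an overshooting length prediction can't make the NAR decode an absurdly long sequence. A toy illustration, with hypothetical values (neither number comes from this diff):

# hypothetical values for illustration only
len_list = [ 742, 18341 ]	# predicted duration per utterance, in codec frames
max_ar_steps = 1500		# the same hard cap the AR path honors
len_list = [ min(l, max_ar_steps) for l in len_list ]	# -> [ 742, 1500 ]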

View File

@@ -18,7 +18,8 @@ from einops import rearrange
 from torch import Tensor
 from tqdm import trange
 
-from ..emb.qnt import trim
+from ..emb.qnt import trim, repeat_extend_audio
+
 import logging
 
 def clamp(n, lo, hi):
@@ -216,7 +217,8 @@ class NAR(Base):
 
 				# fill with mock tokens
 				# to-do: repeat with the input prompt, as per training
-				prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
+				#prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
+				prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]
 
 				# to-do: figure out why this fails when I copy some things from ar_nar
 				for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
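
repeat_extend_audio itself lives in ..emb.qnt and isn't shown in this diff. A minimal sketch of what such a helper plausibly does, assuming qnt is a [seq_len, n_levels] tensor of codebook indices and target is a length in frames:

import torch

def repeat_extend_audio( qnt, target ):
	# tile the prompt's codes along the time axis until they cover
	# at least `target` frames, then trim off the excess
	pieces = []
	length = 0
	while length < target:
		pieces.append( qnt )
		length += qnt.shape[0]
	return torch.cat( pieces )[:target]

Seeding prev_list this way mirrors what the model sees in training (per the to-do note above), instead of starting every position from the stop token.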
@@ -251,7 +253,8 @@
 					prev_list=prev_list,
 					quant_levels=quant_levels,
-					temperature=sampling_temperature,
+					#temperature=sampling_temperature,
+					temperature=1.0 if n == 0 else sampling_temperature,
 					min_temperature=sampling_min_temperature,
 					top_p=sampling_top_p,
 					top_k=sampling_top_k,
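
The changed line pins the first NAR pass (n == 0) to temperature 1.0 and only applies the configured sampling_temperature on later refinement levels, presumably so the identically-seeded prompt tokens don't all collapse toward the same outputs on the first step. This is not the repo's sampler, just a sketch of the standard mechanism that knob drives:

import torch

def sample( logits, temperature=1.0 ):
	# temperature=1.0 leaves the softmax distribution untouched; values
	# below 1.0 sharpen it toward the argmax, reducing sample diversity
	probs = torch.softmax( logits / max( temperature, 1e-6 ), dim=-1 )
	return torch.multinomial( probs, num_samples=1 )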