the NAR only dream is dead (it just won't work)

This commit is contained in:
mrq 2024-06-12 19:49:47 -05:00
parent a9353cf9fa
commit bcf3910a17
3 changed files with 53 additions and 27 deletions

View File

@ -1,6 +1,7 @@
import argparse
from pathlib import Path
from .inference import TTS
from .config import cfg
def path_list(arg):
return [Path(p) for p in arg.split(";")]

View File

@ -141,11 +141,14 @@ class TTS():
sr = None
model_ar = None
model_len = None
model_nar = None
for name, engine in self.engines.items():
if "ar" in engine.hyper_config.capabilities:
model_ar = engine.module
if "len" in engine.hyper_config.capabilities:
model_len = engine.module
if "nar" in engine.hyper_config.capabilities:
model_nar = engine.module
@ -168,33 +171,37 @@ class TTS():
# AR temp: 1
# NAR temp: 0.05
# prom size: 3
"""
resps_list = engine(text_list=text_list, proms_list=proms_list, max_steps=max_ar_steps, sampling_temperature=ar_temp)
resps_list = engine(text_list=text_list, proms_list=proms_list, resps_list=resps_list, sampling_temperature=nar_temp)
"""
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
sampling_length_penalty=length_penalty,
sampling_beam_width=beam_width,
sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta,
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
)
"""
"""
if model_ar is not None:
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
sampling_length_penalty=length_penalty,
sampling_beam_width=beam_width,
sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta,
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
)
elif model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10 ) # don't need more than that
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
)
else:
raise Exception("!")
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
wavs.append(wav)

View File

@ -229,7 +229,25 @@ class NAR(Base):
quant_levels=quant_levels,
)
"""
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
"""
resps_list = super().sample(
logits=logits,
resps_list=prev_list,
quant_levels=quant_levels,
temperature=1.0 if n == 0 else sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
)
if n == 0:
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]