the NAR only dream is dead (it just won't work)

This commit is contained in:
mrq 2024-06-12 19:49:47 -05:00
parent a9353cf9fa
commit bcf3910a17
3 changed files with 53 additions and 27 deletions

View File

@@ -1,6 +1,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from .inference import TTS from .inference import TTS
from .config import cfg
def path_list(arg): def path_list(arg):
return [Path(p) for p in arg.split(";")] return [Path(p) for p in arg.split(";")]

View File

@@ -141,11 +141,14 @@ class TTS():
sr = None sr = None
model_ar = None model_ar = None
model_len = None
model_nar = None model_nar = None
for name, engine in self.engines.items(): for name, engine in self.engines.items():
if "ar" in engine.hyper_config.capabilities: if "ar" in engine.hyper_config.capabilities:
model_ar = engine.module model_ar = engine.module
if "len" in engine.hyper_config.capabilities:
model_len = engine.module
if "nar" in engine.hyper_config.capabilities: if "nar" in engine.hyper_config.capabilities:
model_nar = engine.module model_nar = engine.module
@@ -168,12 +171,7 @@ class TTS():
# AR temp: 1 # AR temp: 1
# NAR temp: 0.05 # NAR temp: 0.05
# prom size: 3 # prom size: 3
if model_ar is not None:
"""
resps_list = engine(text_list=text_list, proms_list=proms_list, max_steps=max_ar_steps, sampling_temperature=ar_temp)
resps_list = engine(text_list=text_list, proms_list=proms_list, resps_list=resps_list, sampling_temperature=nar_temp)
"""
resps_list = model_ar( resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context, text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp, sampling_temperature=ar_temp,
@@ -193,8 +191,17 @@ class TTS():
sampling_top_p=top_p, sampling_top_k=top_k, sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
) )
""" elif model_len is not None:
""" len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10 ) # don't need more than that
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
)
else:
raise Exception("!")
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device) wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
wavs.append(wav) wavs.append(wav)

View File

@@ -229,7 +229,25 @@ class NAR(Base):
quant_levels=quant_levels, quant_levels=quant_levels,
) )
"""
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ] resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
"""
resps_list = super().sample(
logits=logits,
resps_list=prev_list,
quant_levels=quant_levels,
temperature=1.0 if n == 0 else sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
)
if n == 0: if n == 0:
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ] prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]