the NAR only dream is dead (it just won't work)
This commit is contained in:
parent
a9353cf9fa
commit
bcf3910a17
|
@ -1,6 +1,7 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from .inference import TTS
|
||||
from .config import cfg
|
||||
|
||||
def path_list(arg):
|
||||
return [Path(p) for p in arg.split(";")]
|
||||
|
|
|
@ -141,11 +141,14 @@ class TTS():
|
|||
sr = None
|
||||
|
||||
model_ar = None
|
||||
model_len = None
|
||||
model_nar = None
|
||||
|
||||
for name, engine in self.engines.items():
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
model_ar = engine.module
|
||||
if "len" in engine.hyper_config.capabilities:
|
||||
model_len = engine.module
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
model_nar = engine.module
|
||||
|
||||
|
@ -168,33 +171,37 @@ class TTS():
|
|||
# AR temp: 1
|
||||
# NAR temp: 0.05
|
||||
# prom size: 3
|
||||
|
||||
"""
|
||||
resps_list = engine(text_list=text_list, proms_list=proms_list, max_steps=max_ar_steps, sampling_temperature=ar_temp)
|
||||
resps_list = engine(text_list=text_list, proms_list=proms_list, resps_list=resps_list, sampling_temperature=nar_temp)
|
||||
"""
|
||||
|
||||
resps_list = model_ar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
|
||||
sampling_temperature=ar_temp,
|
||||
sampling_min_temperature=min_ar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
sampling_length_penalty=length_penalty,
|
||||
sampling_beam_width=beam_width,
|
||||
sampling_mirostat_tau=mirostat_tau,
|
||||
sampling_mirostat_eta=mirostat_eta,
|
||||
)
|
||||
resps_list = model_nar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
|
||||
max_levels=max_nar_levels,
|
||||
sampling_temperature=nar_temp,
|
||||
sampling_min_temperature=min_nar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
)
|
||||
"""
|
||||
"""
|
||||
if model_ar is not None:
|
||||
resps_list = model_ar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
|
||||
sampling_temperature=ar_temp,
|
||||
sampling_min_temperature=min_ar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
sampling_length_penalty=length_penalty,
|
||||
sampling_beam_width=beam_width,
|
||||
sampling_mirostat_tau=mirostat_tau,
|
||||
sampling_mirostat_eta=mirostat_eta,
|
||||
)
|
||||
resps_list = model_nar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
|
||||
max_levels=max_nar_levels,
|
||||
sampling_temperature=nar_temp,
|
||||
sampling_min_temperature=min_nar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
)
|
||||
elif model_len is not None:
|
||||
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10 ) # don't need more than that
|
||||
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
|
||||
max_levels=max_nar_levels,
|
||||
sampling_temperature=nar_temp,
|
||||
sampling_min_temperature=min_nar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
)
|
||||
else:
|
||||
raise Exception("!")
|
||||
|
||||
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
|
||||
wavs.append(wav)
|
||||
|
|
|
@ -229,7 +229,25 @@ class NAR(Base):
|
|||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
"""
|
||||
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
|
||||
"""
|
||||
|
||||
resps_list = super().sample(
|
||||
logits=logits,
|
||||
resps_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
temperature=1.0 if n == 0 else sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat=mirostat,
|
||||
)
|
||||
|
||||
if n == 0:
|
||||
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
|
||||
|
|
Loading…
Reference in New Issue
Block a user