the NAR only dream is dead (it just won't work)
This commit is contained in:
parent
a9353cf9fa
commit
bcf3910a17
|
@ -1,6 +1,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .inference import TTS
|
from .inference import TTS
|
||||||
|
from .config import cfg
|
||||||
|
|
||||||
def path_list(arg):
|
def path_list(arg):
|
||||||
return [Path(p) for p in arg.split(";")]
|
return [Path(p) for p in arg.split(";")]
|
||||||
|
|
|
@ -141,11 +141,14 @@ class TTS():
|
||||||
sr = None
|
sr = None
|
||||||
|
|
||||||
model_ar = None
|
model_ar = None
|
||||||
|
model_len = None
|
||||||
model_nar = None
|
model_nar = None
|
||||||
|
|
||||||
for name, engine in self.engines.items():
|
for name, engine in self.engines.items():
|
||||||
if "ar" in engine.hyper_config.capabilities:
|
if "ar" in engine.hyper_config.capabilities:
|
||||||
model_ar = engine.module
|
model_ar = engine.module
|
||||||
|
if "len" in engine.hyper_config.capabilities:
|
||||||
|
model_len = engine.module
|
||||||
if "nar" in engine.hyper_config.capabilities:
|
if "nar" in engine.hyper_config.capabilities:
|
||||||
model_nar = engine.module
|
model_nar = engine.module
|
||||||
|
|
||||||
|
@ -168,12 +171,7 @@ class TTS():
|
||||||
# AR temp: 1
|
# AR temp: 1
|
||||||
# NAR temp: 0.05
|
# NAR temp: 0.05
|
||||||
# prom size: 3
|
# prom size: 3
|
||||||
|
if model_ar is not None:
|
||||||
"""
|
|
||||||
resps_list = engine(text_list=text_list, proms_list=proms_list, max_steps=max_ar_steps, sampling_temperature=ar_temp)
|
|
||||||
resps_list = engine(text_list=text_list, proms_list=proms_list, resps_list=resps_list, sampling_temperature=nar_temp)
|
|
||||||
"""
|
|
||||||
|
|
||||||
resps_list = model_ar(
|
resps_list = model_ar(
|
||||||
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
|
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
|
||||||
sampling_temperature=ar_temp,
|
sampling_temperature=ar_temp,
|
||||||
|
@ -193,8 +191,17 @@ class TTS():
|
||||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||||
)
|
)
|
||||||
"""
|
elif model_len is not None:
|
||||||
"""
|
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10 ) # don't need more than that
|
||||||
|
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
|
||||||
|
max_levels=max_nar_levels,
|
||||||
|
sampling_temperature=nar_temp,
|
||||||
|
sampling_min_temperature=min_nar_temp,
|
||||||
|
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||||
|
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception("!")
|
||||||
|
|
||||||
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
|
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
|
||||||
wavs.append(wav)
|
wavs.append(wav)
|
||||||
|
|
|
@ -229,7 +229,25 @@ class NAR(Base):
|
||||||
quant_levels=quant_levels,
|
quant_levels=quant_levels,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
|
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
|
||||||
|
"""
|
||||||
|
|
||||||
|
resps_list = super().sample(
|
||||||
|
logits=logits,
|
||||||
|
resps_list=prev_list,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
|
||||||
|
temperature=1.0 if n == 0 else sampling_temperature,
|
||||||
|
min_temperature=sampling_min_temperature,
|
||||||
|
top_p=sampling_top_p,
|
||||||
|
top_k=sampling_top_k,
|
||||||
|
repetition_penalty=sampling_repetition_penalty,
|
||||||
|
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||||
|
#length_penalty=sampling_length_penalty,
|
||||||
|
#beam_width=sampling_beam_width,
|
||||||
|
#mirostat=mirostat,
|
||||||
|
)
|
||||||
|
|
||||||
if n == 0:
|
if n == 0:
|
||||||
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
|
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user