fixed NAR-len issues with non-english maybe (langs weren't being passed), added interface to inference in batches through tts.batched_inference (no support for rolling context/prefixes because there's no way to do that), demo page uses batched inferencing now
This commit is contained in:
parent
1f54bf5b40
commit
5d80a2d0d4
|
@ -744,6 +744,8 @@ class Inference:
|
|||
|
||||
normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though
|
||||
|
||||
batch_size: int = 16 # I don't know what would be a good batch size
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
|
|
156
vall_e/demo.py
156
vall_e/demo.py
|
@ -37,12 +37,40 @@ def encode(path):
|
|||
return ""
|
||||
return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
|
||||
|
||||
def safe_inference( tts, out_path, **kwargs ):
|
||||
if args.skip_existing and out_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
tts.inference( out_path=out_path, **kwargs )
|
||||
except Exception as e:
|
||||
raise e
|
||||
print(f'Error while processing {out_path}: {e}')
|
||||
|
||||
def safe_batched_inference( tts, **kwargs ):
|
||||
try:
|
||||
tts.batched_inference( **kwargs )
|
||||
except Exception as e:
|
||||
raise e
|
||||
print(f'Error while processing batch: {e}')
|
||||
|
||||
def process_batch( tts, inputs, kwargs={} ):
|
||||
kwargs = kwargs | dict(
|
||||
texts=[ x[0] for x in inputs ],
|
||||
references=[ x[1] for x in inputs ],
|
||||
languages=[ x[2] for x in inputs ],
|
||||
out_paths=[ x[3] for x in inputs ],
|
||||
)
|
||||
|
||||
safe_batched_inference( tts, **kwargs )
|
||||
|
||||
# Would be downright sugoi if I could incorporate this with into __main__
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("VALL-E TTS Demo")
|
||||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
parser.add_argument("--model", type=Path, default=None)
|
||||
parser.add_argument("--batch-size", type=int, default=0)
|
||||
|
||||
parser.add_argument("--demo-dir", type=Path, default=None)
|
||||
parser.add_argument("--skip-existing", action="store_true")
|
||||
|
@ -61,14 +89,14 @@ def main():
|
|||
parser.add_argument("--out-path", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||
parser.add_argument("--max-steps", type=int, default=25)
|
||||
parser.add_argument("--max-steps", type=int, default=50)
|
||||
parser.add_argument("--max-levels", type=int, default=7)
|
||||
|
||||
parser.add_argument("--ar-temperature", type=float, default=1.0)
|
||||
parser.add_argument("--nar-temperature", type=float, default=0.0)
|
||||
parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
|
||||
parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
|
||||
parser.add_argument("--input-prompt-length", type=float, default=3.0)
|
||||
parser.add_argument("--input-prompt-length", type=float, default=5.0)
|
||||
parser.add_argument("--input-prompt-prefix", action="store_true")
|
||||
parser.add_argument("--prefix-silence", type=float, default=0.0)
|
||||
parser.add_argument("--cfg-strength", type=float, default=1.0)
|
||||
|
@ -90,18 +118,6 @@ def main():
|
|||
parser.add_argument("--dry-base", type=float, default=1.75)
|
||||
parser.add_argument("--dry-allowed-length", type=int, default=2)
|
||||
|
||||
parser.add_argument("--entropix-sampling", action="store_true")
|
||||
|
||||
parser.add_argument("--layer-skip", action="store_true")
|
||||
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
|
||||
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
|
||||
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
|
||||
parser.add_argument("--refine-on-stop", action="store_true")
|
||||
|
||||
# experimental settings
|
||||
parser.add_argument("--load-from-artifact", type=Path, default=None)
|
||||
parser.add_argument("--denoise-start", type=float, default=0.0)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=None)
|
||||
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
|
@ -151,30 +167,6 @@ def main():
|
|||
comparison_kwargs["disabled"]["ar_temperature"] = 0.0
|
||||
comparison_kwargs["enabled"]["use_lora"] = False
|
||||
comparison_kwargs["enabled"]["ar_temperature"] = 0.95
|
||||
elif args.comparison == "entropix-sampling":
|
||||
comparison_kwargs["suffix"] = "entropix_sampling"
|
||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||
|
||||
comparison_kwargs["disabled"]["entropix_sampling"] = False
|
||||
comparison_kwargs["disabled"]["ar_temperature"] = args.ar_temperature
|
||||
comparison_kwargs["disabled"]["top_k"] = args.top_k
|
||||
comparison_kwargs["disabled"]["top_p"] = args.top_p
|
||||
comparison_kwargs["enabled"]["entropix_sampling"] = True
|
||||
comparison_kwargs["enabled"]["ar_temperature"] = 0.666
|
||||
comparison_kwargs["enabled"]["top_k"] = 27
|
||||
comparison_kwargs["enabled"]["top_p"] = 0.9
|
||||
elif args.comparison == "layerskip":
|
||||
comparison_kwargs["suffix"] = "layerskip"
|
||||
comparison_kwargs["titles"] = [f"Without LayerSkip", "With LayerSkip"]
|
||||
|
||||
comparison_kwargs["disabled"]["layer_skip"] = False
|
||||
comparison_kwargs["enabled"]["layer_skip"] = True
|
||||
elif args.comparison == "refine-on-stop":
|
||||
comparison_kwargs["suffix"] = "refine-on-stop"
|
||||
comparison_kwargs["titles"] = [f"Without Ro<S>", "With Ro<S>"]
|
||||
|
||||
comparison_kwargs["disabled"]["refine_on_stop"] = False
|
||||
comparison_kwargs["enabled"]["refine_on_stop"] = True
|
||||
elif args.comparison == "ar-temp":
|
||||
current_temperature = args.ar_temperature
|
||||
other_temperature = 1.0
|
||||
|
@ -254,18 +246,15 @@ def main():
|
|||
beam_width=args.beam_width,
|
||||
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
|
||||
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
|
||||
entropix_sampling=args.entropix_sampling,
|
||||
layer_skip=args.layer_skip,
|
||||
layer_skip_exit_layer=args.layer_skip_exit_layer,
|
||||
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
|
||||
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
|
||||
refine_on_stop=args.refine_on_stop,
|
||||
denoise_start=args.denoise_start,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
input_prompt_prefix=args.input_prompt_prefix,
|
||||
prefix_silence=args.prefix_silence,
|
||||
cfg_strength=args.cfg_strength,
|
||||
cfg_rescale=args.cfg_rescale,
|
||||
|
||||
seed = args.seed if args.seed else int(time.time()),
|
||||
tqdm = True,
|
||||
batch_size = args.batch_size,
|
||||
)
|
||||
|
||||
# replace values in our template
|
||||
|
@ -326,6 +315,9 @@ def main():
|
|||
decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
|
||||
decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
|
||||
|
||||
inputs = []
|
||||
outputs = []
|
||||
comparison_inputs = []
|
||||
for k, sample_dir in samples_dirs.items():
|
||||
if not sample_dir.exists():
|
||||
continue
|
||||
|
@ -349,6 +341,13 @@ def main():
|
|||
audio_samples += [ out_path_comparison ]
|
||||
audio_samples += [ p if p.exists() else None for p in external_sources ]
|
||||
|
||||
"""
|
||||
# manual invocation
|
||||
cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
|
||||
# F5
|
||||
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
|
||||
"""
|
||||
|
||||
if not args.random_prompts or k == "librispeech":
|
||||
audio_samples += [ reference ]
|
||||
|
||||
|
@ -357,51 +356,22 @@ def main():
|
|||
audio_samples,
|
||||
))
|
||||
|
||||
seed = args.seed if args.seed else int(time.time())
|
||||
|
||||
"""
|
||||
# manual invocation
|
||||
cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
|
||||
# F5
|
||||
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
|
||||
"""
|
||||
|
||||
kwargs = dict(
|
||||
text=text,
|
||||
references=[prompt],
|
||||
language=language,
|
||||
seed=seed,
|
||||
tqdm=False,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
|
||||
def safe_inference( out_path=out_path ):
|
||||
if args.skip_existing and out_path.exists():
|
||||
return
|
||||
|
||||
# swap model config swap
|
||||
"""
|
||||
if "dtype" in kwargs or "amp" in kwargs:
|
||||
dtype = kwargs.pop("dtype", args.dtype)
|
||||
amp = kwargs.pop("amp", args.amp)
|
||||
|
||||
del tts
|
||||
tts = TTS( config=args.yaml, device=args.device, dtype=dtype, amp=amp )
|
||||
"""
|
||||
try:
|
||||
tts.inference( out_path=out_path, **kwargs )
|
||||
except Exception as e:
|
||||
raise e
|
||||
print(f'Error while processing {out_path}: {e}')
|
||||
|
||||
# segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
|
||||
if args.comparison:
|
||||
kwargs.update( comparison_kwargs["enabled"] )
|
||||
safe_inference(out_path_comparison)
|
||||
kwargs.update( comparison_kwargs["disabled"] )
|
||||
comparison_inputs.append((text, prompt, language, out_path_comparison))
|
||||
|
||||
safe_inference()
|
||||
inputs.append((text, prompt, language, out_path))
|
||||
|
||||
# collate entries into HTML
|
||||
outputs.append((k, samples))
|
||||
|
||||
if inputs:
|
||||
process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )
|
||||
|
||||
if comparison_inputs:
|
||||
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
||||
|
||||
# collate entries into HTML
|
||||
for k, samples in outputs:
|
||||
samples = [
|
||||
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
||||
"".join( [
|
||||
|
@ -415,12 +385,12 @@ def main():
|
|||
# write audio into template
|
||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||
|
||||
if args.comparison:
|
||||
disabled, enabled = comparison_kwargs["titles"]
|
||||
if args.random_prompts:
|
||||
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
|
||||
else:
|
||||
html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
|
||||
if args.comparison:
|
||||
disabled, enabled = comparison_kwargs["titles"]
|
||||
if args.random_prompts:
|
||||
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
|
||||
else:
|
||||
html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
|
||||
|
||||
# write demo page
|
||||
open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )
|
||||
|
|
|
@ -68,6 +68,7 @@ class TTS():
|
|||
self.device = device
|
||||
self.dtype = cfg.inference.dtype
|
||||
self.amp = amp
|
||||
self.batch_size = cfg.inference.batch_size
|
||||
|
||||
self.model_kwargs = {}
|
||||
if attention:
|
||||
|
@ -120,10 +121,13 @@ class TTS():
|
|||
if isinstance( paths, str ):
|
||||
paths = [ Path(p) for p in paths.split(";") ]
|
||||
|
||||
# merge inputs
|
||||
# not already a list
|
||||
if isinstance( paths, Path ):
|
||||
paths = [ paths ]
|
||||
|
||||
proms = []
|
||||
|
||||
# merge inputs
|
||||
for path in paths:
|
||||
prom = qnt.encode_from_file(path)
|
||||
if hasattr( prom, "codes" ):
|
||||
|
@ -185,26 +189,159 @@ class TTS():
|
|||
modality = cfg.model.name
|
||||
return modality
|
||||
|
||||
# makes use of being able to batch inputs seamlessly by automatically batching
|
||||
# this is NOT the default because it absolutely cannot make use of rolling context / prefixing
|
||||
@torch.inference_mode()
|
||||
def batched_inference(
|
||||
self,
|
||||
texts,
|
||||
references=None,
|
||||
languages=None,
|
||||
text_languages=None,
|
||||
out_paths=None,
|
||||
**sampling_kwargs,
|
||||
):
|
||||
batch_size = sampling_kwargs.pop("batch_size", self.batch_size)
|
||||
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
|
||||
modality = sampling_kwargs.pop("modality", "auto")
|
||||
seed = sampling_kwargs.pop("seed", None)
|
||||
tqdm = sampling_kwargs.pop("tqdm", True)
|
||||
use_lora = sampling_kwargs.pop("use_lora", None)
|
||||
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
||||
amp = sampling_kwargs.pop("amp", self.amp)
|
||||
|
||||
model_ar = None
|
||||
model_len = None
|
||||
model_nar = None
|
||||
|
||||
for name, engine in self.engines.items():
|
||||
if model_ar is None and "ar" in engine.hyper_config.capabilities:
|
||||
model_ar = engine.module
|
||||
if model_len is None and "len" in engine.hyper_config.capabilities:
|
||||
model_len = engine.module
|
||||
if model_nar is None and "nar" in engine.hyper_config.capabilities:
|
||||
model_nar = engine.module
|
||||
|
||||
modality = self.modality( modality )
|
||||
# force AR+NAR
|
||||
if modality == "ar+nar":
|
||||
model_len = None
|
||||
# force NAR-len
|
||||
elif modality == "nar-len":
|
||||
model_ar = None
|
||||
|
||||
samples = len(texts)
|
||||
# fill with null input proms
|
||||
if not references:
|
||||
references = [ None for _ in range(samples) ]
|
||||
# fill with english
|
||||
if not languages:
|
||||
languages = [ "en" for _ in range(samples) ]
|
||||
if not out_paths:
|
||||
out_paths = [ None for _ in range(samples) ]
|
||||
# use the audio language to phonemize the text
|
||||
if not text_languages:
|
||||
text_languages = languages
|
||||
|
||||
# tensorfy inputs
|
||||
for i in range( samples ):
|
||||
texts[i] = self.encode_text( texts[i], language=text_languages[i] )
|
||||
references[i] = self.encode_audio( references[i], trim_length=input_prompt_length ) if references[i] else None
|
||||
languages[i] = self.encode_lang( languages[i] )
|
||||
|
||||
texts[i] = to_device(texts[i], device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
||||
references[i] = to_device(references[i], device=self.device, dtype=torch.int16)
|
||||
languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8)
|
||||
|
||||
# create batches
|
||||
batches = []
|
||||
buffer = ([], [], [], [])
|
||||
for batch in zip( texts, references, languages, out_paths ):
|
||||
# flush
|
||||
if len(buffer[0]) >= batch_size:
|
||||
batches.append(buffer)
|
||||
buffer = ([], [], [], [])
|
||||
|
||||
# insert into buffer
|
||||
for i, x in enumerate( batch ):
|
||||
buffer[i].append(x)
|
||||
|
||||
# flush
|
||||
if len(buffer[0]) >= batch_size:
|
||||
batches.append(buffer)
|
||||
buffer = ([], [], [], [])
|
||||
|
||||
wavs = []
|
||||
for texts, proms, langs, out_paths in batches:
|
||||
seed = set_seed(seed)
|
||||
batch_size = len(texts)
|
||||
input_kwargs = dict(
|
||||
text_list=texts,
|
||||
proms_list=proms,
|
||||
lang_list=langs,
|
||||
disable_tqdm=not tqdm,
|
||||
use_lora=use_lora,
|
||||
)
|
||||
|
||||
with torch.autocast("cuda", dtype=dtype, enabled=amp):
|
||||
if model_len is not None:
|
||||
# extra kwargs
|
||||
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
|
||||
nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
|
||||
|
||||
len_list = model_len( **input_kwargs, task_list=["len"]*batch_size, **{"max_duration": 5} ) # "max_duration" is max tokens
|
||||
|
||||
# add an additional X seconds
|
||||
len_list = [ int(l * duration_padding) for l in len_list ]
|
||||
|
||||
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"]*batch_size,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
elif model_ar is not None:
|
||||
resps_list = model_ar(
|
||||
**input_kwargs, task_list=["tts"]*batch_size,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
|
||||
resps_list = model_nar(
|
||||
**input_kwargs, resps_list=resps_list, task_list=["tts"]*batch_size,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
else:
|
||||
raise Exception("!")
|
||||
|
||||
for resp, out_path in zip( resps_list, out_paths ):
|
||||
if out_path:
|
||||
wav, sr = qnt.decode_to_file(resp, out_path, device=self.device)
|
||||
else:
|
||||
wav, sr = qnt.decode(resp, device=self.device)
|
||||
wavs.append(wav)
|
||||
return wavs
|
||||
|
||||
# naive serial inferencing
|
||||
# will automatically split a text into pieces (if requested) piece by piece
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text,
|
||||
references,
|
||||
text_language=None,
|
||||
language="en",
|
||||
text_language=None,
|
||||
task="tts",
|
||||
modality="auto",
|
||||
|
||||
input_prompt_length = 0,
|
||||
|
||||
seed = None,
|
||||
out_path=None,
|
||||
tqdm=True,
|
||||
use_lora=None,
|
||||
**sampling_kwargs,
|
||||
):
|
||||
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
|
||||
modality = sampling_kwargs.pop("modality", "auto")
|
||||
seed = sampling_kwargs.pop("seed", None)
|
||||
tqdm = sampling_kwargs.pop("tqdm", True)
|
||||
use_lora = sampling_kwargs.pop("use_lora", None)
|
||||
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
||||
amp = sampling_kwargs.pop("amp", self.amp)
|
||||
|
||||
if not text_language:
|
||||
text_language = language
|
||||
|
||||
lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
|
||||
|
||||
wavs = []
|
||||
|
@ -239,7 +376,7 @@ class TTS():
|
|||
resp = to_device(resp, device=self.device, dtype=torch.int16)
|
||||
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
||||
|
||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||
with torch.autocast("cuda", dtype=dtype, enabled=amp):
|
||||
model = model_ar if model_ar is not None else model_nar
|
||||
if model is not None:
|
||||
text_list = model(
|
||||
|
@ -275,14 +412,20 @@ class TTS():
|
|||
phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
||||
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
||||
|
||||
# to-do: add in case for experimental.hf model
|
||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||
with torch.autocast("cuda", dtype=dtype, enabled=amp):
|
||||
input_kwargs = dict(
|
||||
text_list=[phns],
|
||||
proms_list=[prom],
|
||||
lang_list=[lang],
|
||||
disable_tqdm=not tqdm,
|
||||
use_lora=use_lora,
|
||||
)
|
||||
if model_len is not None:
|
||||
# extra kwargs
|
||||
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
|
||||
nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
|
||||
|
||||
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # "max_duration" is max tokens
|
||||
len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens
|
||||
|
||||
# add an additional X seconds
|
||||
len_list = [ int(l * duration_padding) for l in len_list ]
|
||||
|
@ -291,9 +434,7 @@ class TTS():
|
|||
if prefix_context is not None:
|
||||
kwargs["prefix_context"] = prefix_context
|
||||
|
||||
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"],
|
||||
disable_tqdm=not tqdm,
|
||||
use_lora=use_lora,
|
||||
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
|
||||
**(sampling_kwargs | kwargs),
|
||||
)
|
||||
elif model_ar is not None:
|
||||
|
@ -302,16 +443,12 @@ class TTS():
|
|||
kwargs["prefix_context"] = prefix_context
|
||||
|
||||
resps_list = model_ar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
|
||||
disable_tqdm=not tqdm,
|
||||
use_lora=use_lora,
|
||||
**input_kwargs, task_list=["tts"],
|
||||
**(sampling_kwargs | kwargs),
|
||||
)
|
||||
|
||||
resps_list = model_nar(
|
||||
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
|
||||
disable_tqdm=not tqdm,
|
||||
use_lora=use_lora,
|
||||
**input_kwargs, resps_list=resps_list, task_list=["tts"],
|
||||
**sampling_kwargs,
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -622,7 +622,7 @@ class AR_NAR(Base):
|
|||
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
|
||||
# sanitize
|
||||
for i, token in enumerate(r):
|
||||
if token > 10:
|
||||
if token > stop_token:
|
||||
r[i][0] = stop_token
|
||||
|
||||
# append tokens
|
||||
|
|
Loading…
Reference in New Issue
Block a user