fixed NAR-len issues with non-English languages, maybe (langs weren't being passed); added an interface for inferencing in batches through tts.batched_inference (no support for rolling context/prefixes because there's no way to do that); the demo page uses batched inferencing now
parent 1f54bf5b40
commit 5d80a2d0d4
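For reference, a minimal usage sketch of the new batched path (not part of this commit; the import path, config path, and prompt/output paths are assumptions, while the keyword names follow tts.batched_inference in the diff below):

# hypothetical usage sketch of the new batched interface; only the keyword names
# (texts / references / languages / out_paths, batch_size) come from this commit
from pathlib import Path
from vall_e.inference import TTS  # assumed import path

tts = TTS( config=Path("./config.yaml") )  # assumed config location

wavs = tts.batched_inference(
    texts=[ "The quick brown fox.", "El zorro marrón salta." ],
    references=[ "./prompts/en.wav", "./prompts/es.wav" ],   # per-utterance prompt audio (assumed paths)
    languages=[ "en", "es" ],        # languages are now forwarded per utterance (the NAR-len language fix)
    out_paths=[ Path("./out/0.wav"), Path("./out/1.wav") ],
    batch_size=2,                    # defaults to cfg.inference.batch_size when omitted
)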
@@ -744,6 +744,8 @@ class Inference:
     normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though
 
+    batch_size: int = 16 # I don't know what would be a good batch size
+
     @property
     def dtype(self):
         if self.weight_dtype == "float16":
vall_e/demo.py (142 lines changed)
@@ -37,12 +37,40 @@ def encode(path):
         return ""
     return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
 
+def safe_inference( tts, out_path, **kwargs ):
+    if args.skip_existing and out_path.exists():
+        return
+
+    try:
+        tts.inference( out_path=out_path, **kwargs )
+    except Exception as e:
+        raise e
+        print(f'Error while processing {out_path}: {e}')
+
+def safe_batched_inference( tts, **kwargs ):
+    try:
+        tts.batched_inference( **kwargs )
+    except Exception as e:
+        raise e
+        print(f'Error while processing batch: {e}')
+
+def process_batch( tts, inputs, kwargs={} ):
+    kwargs = kwargs | dict(
+        texts=[ x[0] for x in inputs ],
+        references=[ x[1] for x in inputs ],
+        languages=[ x[2] for x in inputs ],
+        out_paths=[ x[3] for x in inputs ],
+    )
+
+    safe_batched_inference( tts, **kwargs )
+
 # Would be downright sugoi if I could incorporate this with into __main__
 def main():
     parser = argparse.ArgumentParser("VALL-E TTS Demo")
 
     parser.add_argument("--yaml", type=Path, default=None)
     parser.add_argument("--model", type=Path, default=None)
+    parser.add_argument("--batch-size", type=int, default=0)
 
     parser.add_argument("--demo-dir", type=Path, default=None)
     parser.add_argument("--skip-existing", action="store_true")
@@ -61,14 +89,14 @@ def main():
     parser.add_argument("--out-path", type=Path, default=None)
 
     parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
-    parser.add_argument("--max-steps", type=int, default=25)
+    parser.add_argument("--max-steps", type=int, default=50)
     parser.add_argument("--max-levels", type=int, default=7)
 
     parser.add_argument("--ar-temperature", type=float, default=1.0)
     parser.add_argument("--nar-temperature", type=float, default=0.0)
     parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
     parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
-    parser.add_argument("--input-prompt-length", type=float, default=3.0)
+    parser.add_argument("--input-prompt-length", type=float, default=5.0)
     parser.add_argument("--input-prompt-prefix", action="store_true")
     parser.add_argument("--prefix-silence", type=float, default=0.0)
     parser.add_argument("--cfg-strength", type=float, default=1.0)
@@ -90,18 +118,6 @@ def main():
     parser.add_argument("--dry-base", type=float, default=1.75)
     parser.add_argument("--dry-allowed-length", type=int, default=2)
 
-    parser.add_argument("--entropix-sampling", action="store_true")
-
-    parser.add_argument("--layer-skip", action="store_true")
-    parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
-    parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
-    parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
-    parser.add_argument("--refine-on-stop", action="store_true")
-
-    # experimental settings
-    parser.add_argument("--load-from-artifact", type=Path, default=None)
-    parser.add_argument("--denoise-start", type=float, default=0.0)
-
     parser.add_argument("--seed", type=int, default=None)
 
     parser.add_argument("--device", type=str, default=None)
@@ -151,30 +167,6 @@ def main():
         comparison_kwargs["disabled"]["ar_temperature"] = 0.0
         comparison_kwargs["enabled"]["use_lora"] = False
         comparison_kwargs["enabled"]["ar_temperature"] = 0.95
-    elif args.comparison == "entropix-sampling":
-        comparison_kwargs["suffix"] = "entropix_sampling"
-        comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
-
-        comparison_kwargs["disabled"]["entropix_sampling"] = False
-        comparison_kwargs["disabled"]["ar_temperature"] = args.ar_temperature
-        comparison_kwargs["disabled"]["top_k"] = args.top_k
-        comparison_kwargs["disabled"]["top_p"] = args.top_p
-        comparison_kwargs["enabled"]["entropix_sampling"] = True
-        comparison_kwargs["enabled"]["ar_temperature"] = 0.666
-        comparison_kwargs["enabled"]["top_k"] = 27
-        comparison_kwargs["enabled"]["top_p"] = 0.9
-    elif args.comparison == "layerskip":
-        comparison_kwargs["suffix"] = "layerskip"
-        comparison_kwargs["titles"] = [f"Without LayerSkip", "With LayerSkip"]
-
-        comparison_kwargs["disabled"]["layer_skip"] = False
-        comparison_kwargs["enabled"]["layer_skip"] = True
-    elif args.comparison == "refine-on-stop":
-        comparison_kwargs["suffix"] = "refine-on-stop"
-        comparison_kwargs["titles"] = [f"Without Ro<S>", "With Ro<S>"]
-
-        comparison_kwargs["disabled"]["refine_on_stop"] = False
-        comparison_kwargs["enabled"]["refine_on_stop"] = True
     elif args.comparison == "ar-temp":
         current_temperature = args.ar_temperature
         other_temperature = 1.0
@@ -254,18 +246,15 @@ def main():
         beam_width=args.beam_width,
         mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
         dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
-        entropix_sampling=args.entropix_sampling,
-        layer_skip=args.layer_skip,
-        layer_skip_exit_layer=args.layer_skip_exit_layer,
-        layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
-        layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
-        refine_on_stop=args.refine_on_stop,
-        denoise_start=args.denoise_start,
         input_prompt_length=args.input_prompt_length,
         input_prompt_prefix=args.input_prompt_prefix,
         prefix_silence=args.prefix_silence,
         cfg_strength=args.cfg_strength,
         cfg_rescale=args.cfg_rescale,
 
+        seed = args.seed if args.seed else int(time.time()),
+        tqdm = True,
+        batch_size = args.batch_size,
     )
 
     # replace values in our template
@@ -326,6 +315,9 @@ def main():
             decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
             decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
 
+    inputs = []
+    outputs = []
+    comparison_inputs = []
     for k, sample_dir in samples_dirs.items():
         if not sample_dir.exists():
             continue
@@ -349,6 +341,13 @@ def main():
                 audio_samples += [ out_path_comparison ]
             audio_samples += [ p if p.exists() else None for p in external_sources ]
 
+            """
+            # manual invocation
+            cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
+            # F5
+            cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
+            """
+
             if not args.random_prompts or k == "librispeech":
                 audio_samples += [ reference ]
 
@@ -357,51 +356,22 @@ def main():
                 audio_samples,
             ))
 
-            seed = args.seed if args.seed else int(time.time())
-
-            """
-            # manual invocation
-            cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
-            # F5
-            cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
-            """
-
-            kwargs = dict(
-                text=text,
-                references=[prompt],
-                language=language,
-                seed=seed,
-                tqdm=False,
-                **sampling_kwargs,
-            )
-
-            def safe_inference( out_path=out_path ):
-                if args.skip_existing and out_path.exists():
-                    return
-
-                # swap model config swap
-                """
-                if "dtype" in kwargs or "amp" in kwargs:
-                    dtype = kwargs.pop("dtype", args.dtype)
-                    amp = kwargs.pop("amp", args.amp)
-
-                    del tts
-                    tts = TTS( config=args.yaml, device=args.device, dtype=dtype, amp=amp )
-                """
-                try:
-                    tts.inference( out_path=out_path, **kwargs )
-                except Exception as e:
-                    raise e
-                    print(f'Error while processing {out_path}: {e}')
-
+            # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
             if args.comparison:
-                kwargs.update( comparison_kwargs["enabled"] )
-                safe_inference(out_path_comparison)
-                kwargs.update( comparison_kwargs["disabled"] )
-
-            safe_inference()
+                comparison_inputs.append((text, prompt, language, out_path_comparison))
+
+            inputs.append((text, prompt, language, out_path))
+
+        outputs.append((k, samples))
+
+    if inputs:
+        process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )
+
+    if comparison_inputs:
+        process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
+
     # collate entries into HTML
+    for k, samples in outputs:
         samples = [
             f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
             "".join( [
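A short sketch (hypothetical values) of how the demo's new helpers chain together: process_batch transposes (text, prompt, language, out_path) tuples into the texts/references/languages/out_paths keyword lists and hands them to safe_batched_inference, which wraps tts.batched_inference:

# hypothetical values; helper names, tuple layout, and the kwargs merge come from the demo changes above
inputs = [
    ( "Hello there.", Path("./prompts/demo.wav"), "en", Path("./out/demo.wav") ),
]
process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )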
@@ -68,6 +68,7 @@ class TTS():
         self.device = device
         self.dtype = cfg.inference.dtype
         self.amp = amp
+        self.batch_size = cfg.inference.batch_size
 
         self.model_kwargs = {}
         if attention:
@@ -120,10 +121,13 @@ class TTS():
         if isinstance( paths, str ):
             paths = [ Path(p) for p in paths.split(";") ]
 
-        # merge inputs
+        # not already a list
+        if isinstance( paths, Path ):
+            paths = [ paths ]
+
         proms = []
 
+        # merge inputs
         for path in paths:
             prom = qnt.encode_from_file(path)
             if hasattr( prom, "codes" ):
@@ -185,26 +189,159 @@ class TTS():
             modality = cfg.model.name
         return modality
 
+    # makes use of being able to batch inputs seamlessly by automatically batching
+    # this is NOT the default because it absolutely cannot make use of rolling context / prefixing
+    @torch.inference_mode()
+    def batched_inference(
+        self,
+        texts,
+        references=None,
+        languages=None,
+        text_languages=None,
+        out_paths=None,
+        **sampling_kwargs,
+    ):
+        batch_size = sampling_kwargs.pop("batch_size", self.batch_size)
+        input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
+        modality = sampling_kwargs.pop("modality", "auto")
+        seed = sampling_kwargs.pop("seed", None)
+        tqdm = sampling_kwargs.pop("tqdm", True)
+        use_lora = sampling_kwargs.pop("use_lora", None)
+        dtype = sampling_kwargs.pop("dtype", self.dtype)
+        amp = sampling_kwargs.pop("amp", self.amp)
+
+        model_ar = None
+        model_len = None
+        model_nar = None
+
+        for name, engine in self.engines.items():
+            if model_ar is None and "ar" in engine.hyper_config.capabilities:
+                model_ar = engine.module
+            if model_len is None and "len" in engine.hyper_config.capabilities:
+                model_len = engine.module
+            if model_nar is None and "nar" in engine.hyper_config.capabilities:
+                model_nar = engine.module
+
+        modality = self.modality( modality )
+        # force AR+NAR
+        if modality == "ar+nar":
+            model_len = None
+        # force NAR-len
+        elif modality == "nar-len":
+            model_ar = None
+
+        samples = len(texts)
+        # fill with null input proms
+        if not references:
+            references = [ None for _ in range(samples) ]
+        # fill with english
+        if not languages:
+            languages = [ "en" for _ in range(samples) ]
+        if not out_paths:
+            out_paths = [ None for _ in range(samples) ]
+        # use the audio language to phonemize the text
+        if not text_languages:
+            text_languages = languages
+
+        # tensorfy inputs
+        for i in range( samples ):
+            texts[i] = self.encode_text( texts[i], language=text_languages[i] )
+            references[i] = self.encode_audio( references[i], trim_length=input_prompt_length ) if references[i] else None
+            languages[i] = self.encode_lang( languages[i] )
+
+            texts[i] = to_device(texts[i], device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
+            references[i] = to_device(references[i], device=self.device, dtype=torch.int16)
+            languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8)
+
+        # create batches
+        batches = []
+        buffer = ([], [], [], [])
+        for batch in zip( texts, references, languages, out_paths ):
+            # flush
+            if len(buffer[0]) >= batch_size:
+                batches.append(buffer)
+                buffer = ([], [], [], [])
+
+            # insert into buffer
+            for i, x in enumerate( batch ):
+                buffer[i].append(x)
+
+        # flush
+        if len(buffer[0]) >= batch_size:
+            batches.append(buffer)
+            buffer = ([], [], [], [])
+
+        wavs = []
+        for texts, proms, langs, out_paths in batches:
+            seed = set_seed(seed)
+            batch_size = len(texts)
+            input_kwargs = dict(
+                text_list=texts,
+                proms_list=proms,
+                lang_list=langs,
+                disable_tqdm=not tqdm,
+                use_lora=use_lora,
+            )
+
+            with torch.autocast("cuda", dtype=dtype, enabled=amp):
+                if model_len is not None:
+                    # extra kwargs
+                    duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
+                    nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
+
+                    len_list = model_len( **input_kwargs, task_list=["len"]*batch_size, **{"max_duration": 5} ) # "max_duration" is max tokens
+
+                    # add an additional X seconds
+                    len_list = [ int(l * duration_padding) for l in len_list ]
+
+                    resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"]*batch_size,
+                        **sampling_kwargs,
+                    )
+                elif model_ar is not None:
+                    resps_list = model_ar(
+                        **input_kwargs, task_list=["tts"]*batch_size,
+                        **sampling_kwargs,
+                    )
+
+                    resps_list = model_nar(
+                        **input_kwargs, resps_list=resps_list, task_list=["tts"]*batch_size,
+                        **sampling_kwargs,
+                    )
+                else:
+                    raise Exception("!")
+
+            for resp, out_path in zip( resps_list, out_paths ):
+                if out_path:
+                    wav, sr = qnt.decode_to_file(resp, out_path, device=self.device)
+                else:
+                    wav, sr = qnt.decode(resp, device=self.device)
+                wavs.append(wav)
+        return wavs
+
+    # naive serial inferencing
+    # will automatically split a text into pieces (if requested) piece by piece
     @torch.inference_mode()
     def inference(
         self,
         text,
         references,
-        text_language=None,
         language="en",
+        text_language=None,
         task="tts",
-        modality="auto",
-
-        input_prompt_length = 0,
-
-        seed = None,
         out_path=None,
-        tqdm=True,
-        use_lora=None,
         **sampling_kwargs,
     ):
+        input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
+        modality = sampling_kwargs.pop("modality", "auto")
+        seed = sampling_kwargs.pop("seed", None)
+        tqdm = sampling_kwargs.pop("tqdm", True)
+        use_lora = sampling_kwargs.pop("use_lora", None)
+        dtype = sampling_kwargs.pop("dtype", self.dtype)
+        amp = sampling_kwargs.pop("amp", self.amp)
+
         if not text_language:
             text_language = language
 
         lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
 
         wavs = []
@@ -239,7 +376,7 @@ class TTS():
             resp = to_device(resp, device=self.device, dtype=torch.int16)
             lang = to_device(lang, device=self.device, dtype=torch.uint8)
 
-            with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
+            with torch.autocast("cuda", dtype=dtype, enabled=amp):
                 model = model_ar if model_ar is not None else model_nar
                 if model is not None:
                     text_list = model(
@@ -275,14 +412,20 @@ class TTS():
             phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
             lang = to_device(lang, device=self.device, dtype=torch.uint8)
 
-            # to-do: add in case for experimental.hf model
-            with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
+            with torch.autocast("cuda", dtype=dtype, enabled=amp):
+                input_kwargs = dict(
+                    text_list=[phns],
+                    proms_list=[prom],
+                    lang_list=[lang],
+                    disable_tqdm=not tqdm,
+                    use_lora=use_lora,
+                )
                 if model_len is not None:
                     # extra kwargs
                     duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
                     nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
 
-                    len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # "max_duration" is max tokens
+                    len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens
 
                     # add an additional X seconds
                     len_list = [ int(l * duration_padding) for l in len_list ]
@@ -291,9 +434,7 @@ class TTS():
                     if prefix_context is not None:
                         kwargs["prefix_context"] = prefix_context
 
-                    resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"],
-                        disable_tqdm=not tqdm,
-                        use_lora=use_lora,
+                    resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
                         **(sampling_kwargs | kwargs),
                     )
                 elif model_ar is not None:
@@ -302,16 +443,12 @@ class TTS():
                         kwargs["prefix_context"] = prefix_context
 
                     resps_list = model_ar(
-                        text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
-                        disable_tqdm=not tqdm,
-                        use_lora=use_lora,
+                        **input_kwargs, task_list=["tts"],
                         **(sampling_kwargs | kwargs),
                     )
 
                     resps_list = model_nar(
-                        text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
-                        disable_tqdm=not tqdm,
-                        use_lora=use_lora,
+                        **input_kwargs, resps_list=resps_list, task_list=["tts"],
                         **sampling_kwargs,
                     )
                 else:
@@ -622,7 +622,7 @@ class AR_NAR(Base):
             r = [ logit[-1:].argmax(dim=1) for logit in logits ]
             # sanitize
             for i, token in enumerate(r):
-                if token > 10:
+                if token > stop_token:
                     r[i][0] = stop_token
 
             # append tokens