added option to set the trim length for an input prompt
This commit is contained in:
parent
d10053d11f
commit
4f61f5c889
|
@ -19,6 +19,7 @@ def main():
|
||||||
|
|
||||||
parser.add_argument("--ar-temp", type=float, default=1.0)
|
parser.add_argument("--ar-temp", type=float, default=1.0)
|
||||||
parser.add_argument("--nar-temp", type=float, default=1.0)
|
parser.add_argument("--nar-temp", type=float, default=1.0)
|
||||||
|
parser.add_argument("--input-prompt-length", type=float, default=3.0)
|
||||||
|
|
||||||
parser.add_argument("--top-p", type=float, default=1.0)
|
parser.add_argument("--top-p", type=float, default=1.0)
|
||||||
parser.add_argument("--top-k", type=int, default=0)
|
parser.add_argument("--top-k", type=int, default=0)
|
||||||
|
@ -32,7 +33,7 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
|
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||||
tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )
|
tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -121,7 +121,7 @@ class TTS():
|
||||||
phones = [ " " if not p else p for p in content ]
|
phones = [ " " if not p else p for p in content ]
|
||||||
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
|
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
|
||||||
|
|
||||||
def encode_audio( self, paths, should_trim=True ):
|
def encode_audio( self, paths, trim_length=0.0 ):
|
||||||
# already a tensor, return it
|
# already a tensor, return it
|
||||||
if isinstance( paths, Tensor ):
|
if isinstance( paths, Tensor ):
|
||||||
return paths
|
return paths
|
||||||
|
@ -133,17 +133,17 @@ class TTS():
|
||||||
# merge inputs
|
# merge inputs
|
||||||
res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
|
res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
|
||||||
|
|
||||||
if should_trim:
|
if trim_length:
|
||||||
res = trim( res, int( 75 * cfg.dataset.prompt_duration ) )
|
res = trim( res, int( 75 * trim_length ) )
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ):
|
def inference( self, text, references, max_ar_steps=6 * 75, input_prompt_length=0.0, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ):
|
||||||
if out_path is None:
|
if out_path is None:
|
||||||
out_path = f"./data/{cfg.start_time}.wav"
|
out_path = f"./data/{cfg.start_time}.wav"
|
||||||
|
|
||||||
prom = self.encode_audio( references )
|
prom = self.encode_audio( references, trim_length=input_prompt_length )
|
||||||
phns = self.encode_text( text )
|
phns = self.encode_text( text )
|
||||||
|
|
||||||
prom = to_device(prom, self.device).to(torch.int16)
|
prom = to_device(prom, self.device).to(torch.int16)
|
||||||
|
|
|
@ -57,6 +57,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||||
parser.add_argument("--text", type=str, default=kwargs["text"])
|
parser.add_argument("--text", type=str, default=kwargs["text"])
|
||||||
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||||
|
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
|
||||||
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
|
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
|
||||||
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
|
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
|
||||||
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
|
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
|
||||||
|
@ -75,6 +76,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
references=[args.references.split(";")],
|
references=[args.references.split(";")],
|
||||||
out_path=tmp.name,
|
out_path=tmp.name,
|
||||||
max_ar_steps=args.max_ar_steps,
|
max_ar_steps=args.max_ar_steps,
|
||||||
|
input_prompt_length=args.input_prompt_length,
|
||||||
ar_temp=args.ar_temp,
|
ar_temp=args.ar_temp,
|
||||||
nar_temp=args.nar_temp,
|
nar_temp=args.nar_temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
|
@ -161,7 +163,9 @@ with ui:
|
||||||
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
|
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
|
||||||
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
|
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
|
||||||
with gr.Column(scale=7):
|
with gr.Column(scale=7):
|
||||||
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.")
|
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.")
|
||||||
|
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user