added option to limit (or exceed) inferenced RVQ-bin levels through the NAR

This commit is contained in:
mrq 2023-09-10 13:50:13 -05:00
parent c74fe2f718
commit ba71020318
5 changed files with 12 additions and 6 deletions

View File

@ -16,6 +16,7 @@ def main():
parser.add_argument("--nar-ckpt", type=Path, default=None) parser.add_argument("--nar-ckpt", type=Path, default=None)
parser.add_argument("--max-ar-steps", type=int, default=6 * 75) parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
parser.add_argument("--max-nar-levels", type=int, default=7)
parser.add_argument("--ar-temp", type=float, default=1.0) parser.add_argument("--ar-temp", type=float, default=1.0)
parser.add_argument("--nar-temp", type=float, default=1.0) parser.add_argument("--nar-temp", type=float, default=1.0)
@ -33,7 +34,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp ) tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty ) tts.inference( text=args.text, references=args.references, out_path=args.out_path, input_prompt_length=args.input_prompt_length, max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -139,7 +139,7 @@ class TTS():
return res return res
@torch.inference_mode() @torch.inference_mode()
def inference( self, text, references, max_ar_steps=6 * 75, input_prompt_length=0.0, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ): def inference( self, text, references, max_ar_steps=6 * 75, max_nar_levels=7, input_prompt_length=0.0, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, repetition_penalty_decay=0.0, length_penalty=0.0, out_path=None ):
if out_path is None: if out_path is None:
out_path = f"./data/{cfg.start_time}.wav" out_path = f"./data/{cfg.start_time}.wav"
@ -152,7 +152,7 @@ class TTS():
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty) resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
resps_list = [r.unsqueeze(-1) for r in resps_list] resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty) resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device) wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)

View File

@ -70,6 +70,7 @@ class AR_NAR(Base):
proms_list: list[Tensor], proms_list: list[Tensor],
resps_list: list[Tensor] | None = None, resps_list: list[Tensor] | None = None,
max_steps: int = 1000, max_steps: int = 1000,
max_levels: int = 7,
sampling_temperature: float = 0.0, sampling_temperature: float = 0.0,
sampling_top_k: int = -100, sampling_top_k: int = -100,
sampling_top_p: float = 1.0, sampling_top_p: float = 1.0,
@ -87,7 +88,7 @@ class AR_NAR(Base):
# is training # is training
if n_levels == self.n_resp_levels: if n_levels == self.n_resp_levels:
if random.random() < 0.25: if random.random() < 0.95:
quant_levels = None quant_levels = None
targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
@ -114,7 +115,7 @@ class AR_NAR(Base):
while True: while True:
level = prev_list[0].shape[-1] level = prev_list[0].shape[-1]
if level >= self.n_resp_levels: if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break break
quant_levels = torch.full((len(text_list),), level, device=device) quant_levels = torch.full((len(text_list),), level, device=device)

View File

@ -56,6 +56,7 @@ class NAR(Base):
text_list: list[Tensor], text_list: list[Tensor],
proms_list: list[Tensor], proms_list: list[Tensor],
resps_list: list[Tensor], resps_list: list[Tensor],
max_levels: int = 7,
sampling_temperature: float = 0.2, sampling_temperature: float = 0.2,
sampling_top_k: int = -100, sampling_top_k: int = -100,
sampling_top_p: float = 1.0, sampling_top_p: float = 1.0,
@ -106,7 +107,7 @@ class NAR(Base):
while True: while True:
level = prev_list[0].shape[-1] - 1 level = prev_list[0].shape[-1] - 1
if level >= self.n_resp_levels: if level >= max_levels: # min(max_levels, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break break
quant_levels = torch.full((len(text_list),), level, device=device) quant_levels = torch.full((len(text_list),), level, device=device)

View File

@ -69,6 +69,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--references", type=str, default=kwargs["reference"]) parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"]) parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75)) parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"]) parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
@ -87,6 +88,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
references=[args.references.split(";")], references=[args.references.split(";")],
out_path=tmp.name, out_path=tmp.name,
max_ar_steps=args.max_ar_steps, max_ar_steps=args.max_ar_steps,
max_nar_levels=args.max_nar_levels,
input_prompt_length=args.input_prompt_length, input_prompt_length=args.input_prompt_length,
ar_temp=args.ar_temp, ar_temp=args.ar_temp,
nar_temp=args.nar_temp, nar_temp=args.nar_temp,
@ -176,6 +178,7 @@ with ui:
with gr.Column(scale=7): with gr.Column(scale=7):
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.") layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=3, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.") layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.") layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")