diff --git a/vall_e/demo.py b/vall_e/demo.py
index f5f9d5d..782183b 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -95,7 +95,7 @@ def main():
 	parser.add_argument("--out-path", type=Path, default=None)
 
 	parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
-	parser.add_argument("--max-steps", type=int, default=50)
+	parser.add_argument("--max-steps", type=int, default=30)
 	parser.add_argument("--max-levels", type=int, default=7)
 
 	parser.add_argument("--ar-temperature", type=float, default=1.0)
diff --git a/vall_e/inference.py b/vall_e/inference.py
index eebff5e..fb6c40b 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -13,6 +13,8 @@ from pathlib import Path
 
 from .emb import g2p, qnt
 from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
+from .emb.transcribe import transcribe
+
 from .utils import to_device, set_seed, clamp, wrapper as ml
 from .config import cfg, Config
@@ -118,7 +120,7 @@ class TTS():
 		return torch.tensor([ id ])
 
 	# to-do: trim before quantizing, instead of after
-	def encode_audio( self, paths, trim_length=5.0 ):
+	def encode_audio( self, paths, trim_length=0.0 ):
 		# already a tensor, return it
 		if isinstance( paths, Tensor ):
 			return paths
@@ -357,6 +359,12 @@ class TTS():
 		use_lora = sampling_kwargs.pop("use_lora", None)
 		dtype = sampling_kwargs.pop("dtype", self.dtype)
 		amp = sampling_kwargs.pop("amp", self.amp)
+
+		voice_convert = sampling_kwargs.pop("voice_convert", None)
+
+		# if no text was given, transcribe the audio to voice-convert from
+		if voice_convert is not None and not text:
+			text = transcribe( voice_convert, model_name="openai/whisper-base", align=False )["text"]
 
 		lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
@@ -430,6 +438,7 @@ class TTS():
 			if auto_text_lang:
 				text_language = deduced_language
 
+			vc_utterance = self.encode_audio( voice_convert, trim_length=0 ) if voice_convert else None
 			prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
 			phns = self.encode_text( line, language=text_language )
 			lang = self.encode_lang( language )
@@ -457,6 +466,8 @@ class TTS():
 				kwargs = {}
 				if prefix_context is not None:
 					kwargs["prefix_context"] = prefix_context
+				if vc_utterance is not None:
+					kwargs["vc_list"] = [ vc_utterance ]
 
 				resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
 					**(sampling_kwargs | kwargs),
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 2dc4fa5..4b57846 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -259,9 +259,12 @@ class AR_NAR(Base):
 			max_steps = math.floor(max_steps * (end_noise - start_noise))
 
 		# to specify the initial mask used
-		mask_list = sampling_kwargs.pop("mask_list", None)
-		if mask_list is not None:
-			len_list = [ x.shape[0] for x in mask_list ]
+		vc_list = sampling_kwargs.pop("vc_list", None)
+		vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
+		vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)
+		if vc_list is not None:
+			vc_list = [ x if x.dim() == 1 else x[:, 0] for x in vc_list ]
+			len_list = [ x.shape[0] for x in vc_list ]
 
 		len_list = [ clamp(l, min_length, max_length) for l in len_list ]
@@ -305,16 +308,24 @@ class AR_NAR(Base):
 			remask_p = 1.0 / (max_steps * 2) if remasking else 0
 			# pick the worst scoring tokens to mask off
 			masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
-			if mask_list is None:
+			# normal masking
+			if vc_list is None or timestep >= vc_threshold:
 				# mask off inputs
 				resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
 				# boolean mask
 				is_masked = [ resps == self.stop_token for resps in resps_list ]
 			else:
-				# mask off inputs
-				resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, mask_list ) ]
+				# randomly mask off a portion of the voice-conversion utterance
+				rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
+				half_mask_list = [ torch.where( rand_mask, self.stop_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
+				# always mask off the last 75 frames, as leaving the tail unmasked causes issues
+				for i, mask in enumerate(half_mask_list):
+					half_mask_list[i][-75:] = self.stop_token
+
+				# mask off inputs per mask
+				resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
 				# boolean mask
-				is_masked = [ resps == mask for resps, mask in zip( resps_list, mask_list ) ]
+				is_masked = [ resps == mask for resps, mask in zip( resps_list, half_mask_list ) ]
 
 			# timestep inputs
 			time_list = [ timestep for _ in range(batch_size) ]
@@ -392,7 +403,7 @@ class AR_NAR(Base):
 			# get sampled tokens
 			sampled_ids = filtered_sampled.ids
 			# keep unmasked tokens
-			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
+			resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
 			# get probability scores
 			scores = [
 				# conjugate to have worse scoring tokens picked for topk
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 516852a..c03ff04 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -42,6 +42,7 @@ if USING_SPACES:
 	from vall_e.emb.qnt import decode_to_wave
 	from vall_e.data import get_lang_symmap, get_random_prompt
 	from vall_e.models.arch import AVAILABLE_ATTENTIONS
+	from vall_e.emb.transcribe import transcribe
 else:
 	from .inference import TTS, cfg
 	from .train import train
@@ -50,6 +51,8 @@ else:
 	from .emb.qnt import decode_to_wave
 	from .data import get_lang_symmap, get_random_prompt
 	from .models.arch import AVAILABLE_ATTENTIONS
+	from .emb.transcribe import transcribe
+
 
 is_windows = sys.platform.startswith("win")
@@ -144,6 +147,11 @@ def load_sample( speaker ):
 
 	return data, (sr, wav)
 
+def gradio_transcribe_input( audio, text, split_by ):
+	if not audio:
+		return ( text, split_by )
+	return ( transcribe( audio, model_name="openai/whisper-base", align=False )["text"], "lines" )
+
 def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None):
 	global tts
@@ -203,6 +211,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--task", type=str, default="tts")
 	parser.add_argument("--modality", type=str, default=kwargs["modality"])
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
+	parser.add_argument("--voice-convert", type=str, default=kwargs["voice-convert"])
 	parser.add_argument("--language", type=str, default=kwargs["language"])
 	parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
 	parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
@@ -275,6 +284,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	sampling_kwargs = dict(
 		split_text_by=args.split_text_by,
 		context_history=args.context_history,
+		voice_convert=args.voice_convert,
 		max_steps=args.max_steps,
 		max_levels=args.max_levels,
 		max_duration=args.max_duration,
@@ -391,6 +401,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	"""
 
 @gradio_wrapper(inputs=layout["training"]["inputs"].keys())
 def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
-	metrics = next(it)
-	yield metrics
+	while True:
+		metrics = next(it)
+		yield metrics
@@ -430,10 +441,13 @@ with ui:
 	with gr.Tab("Text-to-Speech"):
 		with gr.Row():
 			with gr.Column(scale=8):
-				layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
+				with gr.Tab("Text"):
+					layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
+				with gr.Tab("Speech"):
+					layout["inference_tts"]["inputs"]["voice-convert"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Guiding utterance.")
 		with gr.Row():
 			with gr.Column(scale=1):
-				layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
+				layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Reference audio for TTS")
 				# layout["inference_tts"]["stop"] = gr.Button(value="Stop")
 				layout["inference_tts"]["outputs"]["output"] = gr.Audio(label="Output")
 				layout["inference_tts"]["buttons"]["inference"] = gr.Button(value="Inference")
@@ -496,6 +510,20 @@ with ui:
 				outputs=[ x for x in layout["inference_tts"]["outputs"].values() if x is not None]
 			)
 
+			# on voice-convert input change, transcribe the audio to fill the text prompt
+			layout["inference_tts"]["inputs"]["voice-convert"].change(
+				gradio_transcribe_input,
+				[
+					layout["inference_tts"]["inputs"]["voice-convert"],
+					layout["inference_tts"]["inputs"]["text"],
+					layout["inference_tts"]["inputs"]["split-text-by"],
+				],
+				[
+					layout["inference_tts"]["inputs"]["text"],
+					layout["inference_tts"]["inputs"]["split-text-by"],
+				]
+			)
+
 	with gr.Tab("Speech to Text"):
 		with gr.Row():
 			with gr.Column(scale=8):
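
For readers following the ar_nar.py change, the sketch below isolates the new voice-conversion seeding rule: for timesteps below vc_threshold, the working sequence is filled from the first quantizer level of the voice-conversion utterance with a random vc_mask_p portion, plus the trailing 75 frames, replaced by the mask/stop token, instead of starting fully masked. This is a minimal standalone sketch, not the patched code: the function name initial_fill, the STOP_TOKEN value, and the toy tensor sizes are assumptions, and the real loop additionally scatters these fill values only at the lowest-confidence positions each step.

# Standalone sketch of the masking rule added to AR_NAR's demasking loop.
# STOP_TOKEN stands in for self.stop_token; shapes and values are toy examples.
import torch

STOP_TOKEN = 1024      # assumption: the model's mask/stop token id
vc_threshold = 0.25    # timesteps below this are seeded from the VC utterance
vc_mask_p = 0.25       # portion of VC tokens randomly masked on those steps

def initial_fill(vc_utterance: torch.Tensor, timestep: float) -> torch.Tensor:
	# keep only the first quantizer level, mirroring `x[:, 0]` in the diff
	vc = vc_utterance if vc_utterance.dim() == 1 else vc_utterance[:, 0]
	if timestep >= vc_threshold:
		# normal masked demasking: every position starts as the mask token
		return torch.full_like(vc, STOP_TOKEN)
	# early steps: copy the VC utterance and mask a random vc_mask_p portion of it
	rand_mask = torch.rand(vc.shape) < vc_mask_p
	fill = torch.where(rand_mask, torch.full_like(vc, STOP_TOKEN), vc)
	# and always mask the trailing 75 frames, as the diff does
	fill[-75:] = STOP_TOKEN
	return fill

# toy usage: a 300-frame, 8-level utterance early in the schedule
print(initial_fill(torch.randint(0, 1024, (300, 8)), timestep=0.1))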