really shoddy voice conversion implementation (it sort of works...)

This commit is contained in:
mrq 2024-12-16 22:54:53 -06:00
parent 8515038968
commit c2e17e287b
4 changed files with 62 additions and 12 deletions

View File

@ -95,7 +95,7 @@ def main():
parser.add_argument("--out-path", type=Path, default=None) parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second) parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
parser.add_argument("--max-steps", type=int, default=50) parser.add_argument("--max-steps", type=int, default=30)
parser.add_argument("--max-levels", type=int, default=7) parser.add_argument("--max-levels", type=int, default=7)
parser.add_argument("--ar-temperature", type=float, default=1.0) parser.add_argument("--ar-temperature", type=float, default=1.0)

View File

@ -13,6 +13,8 @@ from pathlib import Path
from .emb import g2p, qnt from .emb import g2p, qnt
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
from .emb.transcribe import transcribe
from .utils import to_device, set_seed, clamp, wrapper as ml from .utils import to_device, set_seed, clamp, wrapper as ml
from .config import cfg, Config from .config import cfg, Config
@ -118,7 +120,7 @@ class TTS():
return torch.tensor([ id ]) return torch.tensor([ id ])
# to-do: trim before quantizing, instead of after # to-do: trim before quantizing, instead of after
def encode_audio( self, paths, trim_length=5.0 ): def encode_audio( self, paths, trim_length=0.0 ):
# already a tensor, return it # already a tensor, return it
if isinstance( paths, Tensor ): if isinstance( paths, Tensor ):
return paths return paths
@ -358,6 +360,12 @@ class TTS():
dtype = sampling_kwargs.pop("dtype", self.dtype) dtype = sampling_kwargs.pop("dtype", self.dtype)
amp = sampling_kwargs.pop("amp", self.amp) amp = sampling_kwargs.pop("amp", self.amp)
voice_convert = sampling_kwargs.pop("voice_convert", None)
# transcribe from audio to voice convert from
if voice_convert is not None and not text:
text = transcribe( voice_convert, model_name="openai/whisper-base", align=False )["text"]
lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences")) lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
wavs = [] wavs = []
@ -430,6 +438,7 @@ class TTS():
if auto_text_lang: if auto_text_lang:
text_language = deduced_language text_language = deduced_language
vc_utterance = self.encode_audio( voice_convert, trim_length=0 ) if voice_convert else None
prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
phns = self.encode_text( line, language=text_language ) phns = self.encode_text( line, language=text_language )
lang = self.encode_lang( language ) lang = self.encode_lang( language )
@ -457,6 +466,8 @@ class TTS():
kwargs = {} kwargs = {}
if prefix_context is not None: if prefix_context is not None:
kwargs["prefix_context"] = prefix_context kwargs["prefix_context"] = prefix_context
if vc_utterance is not None:
kwargs["vc_list"] = [ vc_utterance ]
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"], resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
**(sampling_kwargs | kwargs), **(sampling_kwargs | kwargs),

View File

@ -259,9 +259,12 @@ class AR_NAR(Base):
max_steps = math.floor(max_steps * (end_noise - start_noise)) max_steps = math.floor(max_steps * (end_noise - start_noise))
# to specify the initial mask used # to specify the initial mask used
mask_list = sampling_kwargs.pop("mask_list", None) vc_list = sampling_kwargs.pop("vc_list", None)
if mask_list is not None: vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
len_list = [ x.shape[0] for x in mask_list ] vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)
if vc_list is not None:
vc_list = [ x if x.dim() == 1 else x[:, 0] for x in vc_list ]
len_list = [ x.shape[0] for x in vc_list ]
len_list = [ clamp(l, min_length, max_length) for l in len_list ] len_list = [ clamp(l, min_length, max_length) for l in len_list ]
@ -305,16 +308,24 @@ class AR_NAR(Base):
remask_p = 1.0 / (max_steps * 2) if remasking else 0 remask_p = 1.0 / (max_steps * 2) if remasking else 0
# pick the worst scoring tokens to mask off # pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
if mask_list is None: # normal masking
if vc_list is None or timestep >= vc_threshold:
# mask off inputs # mask off inputs
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ] resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask # boolean mask
is_masked = [ resps == self.stop_token for resps in resps_list ] is_masked = [ resps == self.stop_token for resps in resps_list ]
else: else:
# mask off inputs # mask off a random portion of the target
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, mask_list ) ] rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
half_mask_list = [ torch.where( rand_mask, self.stop_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
# always set the last end as masked off because it causes issues
for i, mask in enumerate(half_mask_list):
half_mask_list[i][-75:] = self.stop_token
#
# mask off inputs per mask
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
# boolean mask # boolean mask
is_masked = [ resps == mask for resps, mask in zip( resps_list, mask_list ) ] is_masked = [ resps == mask for resps, mask in zip( resps_list, half_mask_list ) ]
# timestep inputs # timestep inputs
time_list = [ timestep for _ in range(batch_size) ] time_list = [ timestep for _ in range(batch_size) ]
@ -392,7 +403,7 @@ class AR_NAR(Base):
# get sampled tokens # get sampled tokens
sampled_ids = filtered_sampled.ids sampled_ids = filtered_sampled.ids
# keep unmasked tokens # keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ] resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# get probability scores # get probability scores
scores = [ scores = [
# conjugate to have worse scoring tokens picked for topk # conjugate to have worse scoring tokens picked for topk

View File

@ -42,6 +42,7 @@ if USING_SPACES:
from vall_e.emb.qnt import decode_to_wave from vall_e.emb.qnt import decode_to_wave
from vall_e.data import get_lang_symmap, get_random_prompt from vall_e.data import get_lang_symmap, get_random_prompt
from vall_e.models.arch import AVAILABLE_ATTENTIONS from vall_e.models.arch import AVAILABLE_ATTENTIONS
from vall_e.emb.transcribe import transcribe
else: else:
from .inference import TTS, cfg from .inference import TTS, cfg
from .train import train from .train import train
@ -50,6 +51,8 @@ else:
from .emb.qnt import decode_to_wave from .emb.qnt import decode_to_wave
from .data import get_lang_symmap, get_random_prompt from .data import get_lang_symmap, get_random_prompt
from .models.arch import AVAILABLE_ATTENTIONS from .models.arch import AVAILABLE_ATTENTIONS
from .emb.transcribe import transcribe
is_windows = sys.platform.startswith("win") is_windows = sys.platform.startswith("win")
@ -144,6 +147,11 @@ def load_sample( speaker ):
return data, (sr, wav) return data, (sr, wav)
def gradio_transcribe_input( audio, text, split_by ):
if not audio:
return ( text, split_by )
return ( transcribe( audio, model_name="openai/whisper-base", align=False )["text"], "lines" )
def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None): def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None):
global tts global tts
@ -203,6 +211,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--task", type=str, default="tts") parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--modality", type=str, default=kwargs["modality"]) parser.add_argument("--modality", type=str, default=kwargs["modality"])
parser.add_argument("--references", type=str, default=kwargs["reference"]) parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--voice-convert", type=str, default=kwargs["voice-convert"])
parser.add_argument("--language", type=str, default=kwargs["language"]) parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--text-language", type=str, default=kwargs["text-language"]) parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"]) parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
@ -275,6 +284,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
sampling_kwargs = dict( sampling_kwargs = dict(
split_text_by=args.split_text_by, split_text_by=args.split_text_by,
context_history=args.context_history, context_history=args.context_history,
voice_convert=args.voice_convert,
max_steps=args.max_steps, max_steps=args.max_steps,
max_levels=args.max_levels, max_levels=args.max_levels,
max_duration=args.max_duration, max_duration=args.max_duration,
@ -391,6 +401,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
""" """
@gradio_wrapper(inputs=layout["training"]["inputs"].keys()) @gradio_wrapper(inputs=layout["training"]["inputs"].keys())
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
while True: while True:
metrics = next(it) metrics = next(it)
yield metrics yield metrics
@ -430,10 +441,13 @@ with ui:
with gr.Tab("Text-to-Speech"): with gr.Tab("Text-to-Speech"):
with gr.Row(): with gr.Row():
with gr.Column(scale=8): with gr.Column(scale=8):
with gr.Tab("Text"):
layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt") layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
with gr.Tab("Speech"):
layout["inference_tts"]["inputs"]["voice-convert"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Guiding utternace.")
with gr.Row(): with gr.Row():
with gr.Column(scale=1): with gr.Column(scale=1):
layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS") layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Reference audio for TTS")
# layout["inference_tts"]["stop"] = gr.Button(value="Stop") # layout["inference_tts"]["stop"] = gr.Button(value="Stop")
layout["inference_tts"]["outputs"]["output"] = gr.Audio(label="Output") layout["inference_tts"]["outputs"]["output"] = gr.Audio(label="Output")
layout["inference_tts"]["buttons"]["inference"] = gr.Button(value="Inference") layout["inference_tts"]["buttons"]["inference"] = gr.Button(value="Inference")
@ -496,6 +510,20 @@ with ui:
outputs=[ x for x in layout["inference_tts"]["outputs"].values() if x is not None] outputs=[ x for x in layout["inference_tts"]["outputs"].values() if x is not None]
) )
# IC
layout["inference_tts"]["inputs"]["voice-convert"].change(
gradio_transcribe_input,
[
layout["inference_tts"]["inputs"]["voice-convert"],
layout["inference_tts"]["inputs"]["text"],
layout["inference_tts"]["inputs"]["split-text-by"],
],
[
layout["inference_tts"]["inputs"]["text"],
layout["inference_tts"]["inputs"]["split-text-by"],
]
)
with gr.Tab("Speech to Text"): with gr.Tab("Speech to Text"):
with gr.Row(): with gr.Row():
with gr.Column(scale=8): with gr.Column(scale=8):