really shoddy voice conversion implementation (it sort of works...)
This commit is contained in:
parent
8515038968
commit
c2e17e287b
|
@ -95,7 +95,7 @@ def main():
|
|||
parser.add_argument("--out-path", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||
parser.add_argument("--max-steps", type=int, default=50)
|
||||
parser.add_argument("--max-steps", type=int, default=30)
|
||||
parser.add_argument("--max-levels", type=int, default=7)
|
||||
|
||||
parser.add_argument("--ar-temperature", type=float, default=1.0)
|
||||
|
|
|
@ -13,6 +13,8 @@ from pathlib import Path
|
|||
|
||||
from .emb import g2p, qnt
|
||||
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
|
||||
from .emb.transcribe import transcribe
|
||||
|
||||
from .utils import to_device, set_seed, clamp, wrapper as ml
|
||||
|
||||
from .config import cfg, Config
|
||||
|
@ -118,7 +120,7 @@ class TTS():
|
|||
return torch.tensor([ id ])
|
||||
|
||||
# to-do: trim before quantizing, instead of after
|
||||
def encode_audio( self, paths, trim_length=5.0 ):
|
||||
def encode_audio( self, paths, trim_length=0.0 ):
|
||||
# already a tensor, return it
|
||||
if isinstance( paths, Tensor ):
|
||||
return paths
|
||||
|
@ -357,6 +359,12 @@ class TTS():
|
|||
use_lora = sampling_kwargs.pop("use_lora", None)
|
||||
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
||||
amp = sampling_kwargs.pop("amp", self.amp)
|
||||
|
||||
voice_convert = sampling_kwargs.pop("voice_convert", None)
|
||||
|
||||
# transcribe from audio to voice convert from
|
||||
if voice_convert is not None and not text:
|
||||
text = transcribe( voice_convert, model_name="openai/whisper-base", align=False )["text"]
|
||||
|
||||
lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
|
||||
|
||||
|
@ -430,6 +438,7 @@ class TTS():
|
|||
if auto_text_lang:
|
||||
text_language = deduced_language
|
||||
|
||||
vc_utterance = self.encode_audio( voice_convert, trim_length=0 ) if voice_convert else None
|
||||
prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
|
||||
phns = self.encode_text( line, language=text_language )
|
||||
lang = self.encode_lang( language )
|
||||
|
@ -457,6 +466,8 @@ class TTS():
|
|||
kwargs = {}
|
||||
if prefix_context is not None:
|
||||
kwargs["prefix_context"] = prefix_context
|
||||
if vc_utterance is not None:
|
||||
kwargs["vc_list"] = [ vc_utterance ]
|
||||
|
||||
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
|
||||
**(sampling_kwargs | kwargs),
|
||||
|
|
|
@ -259,9 +259,12 @@ class AR_NAR(Base):
|
|||
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
||||
|
||||
# to specify the initial mask used
|
||||
mask_list = sampling_kwargs.pop("mask_list", None)
|
||||
if mask_list is not None:
|
||||
len_list = [ x.shape[0] for x in mask_list ]
|
||||
vc_list = sampling_kwargs.pop("vc_list", None)
|
||||
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
|
||||
vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)
|
||||
if vc_list is not None:
|
||||
vc_list = [ x if x.dim() == 1 else x[:, 0] for x in vc_list ]
|
||||
len_list = [ x.shape[0] for x in vc_list ]
|
||||
|
||||
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
|
||||
|
||||
|
@ -305,16 +308,24 @@ class AR_NAR(Base):
|
|||
remask_p = 1.0 / (max_steps * 2) if remasking else 0
|
||||
# pick the worst scoring tokens to mask off
|
||||
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
||||
if mask_list is None:
|
||||
# normal masking
|
||||
if vc_list is None or timestep >= vc_threshold:
|
||||
# mask off inputs
|
||||
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||
# boolean mask
|
||||
is_masked = [ resps == self.stop_token for resps in resps_list ]
|
||||
else:
|
||||
# mask off inputs
|
||||
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, mask_list ) ]
|
||||
# mask off a random portion of the target
|
||||
rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
|
||||
half_mask_list = [ torch.where( rand_mask, self.stop_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
|
||||
# always set the last end as masked off because it causes issues
|
||||
for i, mask in enumerate(half_mask_list):
|
||||
half_mask_list[i][-75:] = self.stop_token
|
||||
#
|
||||
# mask off inputs per mask
|
||||
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
|
||||
# boolean mask
|
||||
is_masked = [ resps == mask for resps, mask in zip( resps_list, mask_list ) ]
|
||||
is_masked = [ resps == mask for resps, mask in zip( resps_list, half_mask_list ) ]
|
||||
|
||||
# timestep inputs
|
||||
time_list = [ timestep for _ in range(batch_size) ]
|
||||
|
@ -392,7 +403,7 @@ class AR_NAR(Base):
|
|||
# get sampled tokens
|
||||
sampled_ids = filtered_sampled.ids
|
||||
# keep unmasked tokens
|
||||
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
||||
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
||||
# get probability scores
|
||||
scores = [
|
||||
# conjugate to have worse scoring tokens picked for topk
|
||||
|
|
|
@ -42,6 +42,7 @@ if USING_SPACES:
|
|||
from vall_e.emb.qnt import decode_to_wave
|
||||
from vall_e.data import get_lang_symmap, get_random_prompt
|
||||
from vall_e.models.arch import AVAILABLE_ATTENTIONS
|
||||
from vall_e.emb.transcribe import transcribe
|
||||
else:
|
||||
from .inference import TTS, cfg
|
||||
from .train import train
|
||||
|
@ -50,6 +51,8 @@ else:
|
|||
from .emb.qnt import decode_to_wave
|
||||
from .data import get_lang_symmap, get_random_prompt
|
||||
from .models.arch import AVAILABLE_ATTENTIONS
|
||||
from .emb.transcribe import transcribe
|
||||
|
||||
|
||||
is_windows = sys.platform.startswith("win")
|
||||
|
||||
|
@ -144,6 +147,11 @@ def load_sample( speaker ):
|
|||
|
||||
return data, (sr, wav)
|
||||
|
||||
def gradio_transcribe_input( audio, text, split_by ):
|
||||
if not audio:
|
||||
return ( text, split_by )
|
||||
return ( transcribe( audio, model_name="openai/whisper-base", align=False )["text"], "lines" )
|
||||
|
||||
def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None):
|
||||
global tts
|
||||
|
||||
|
@ -203,6 +211,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
parser.add_argument("--task", type=str, default="tts")
|
||||
parser.add_argument("--modality", type=str, default=kwargs["modality"])
|
||||
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||
parser.add_argument("--voice-convert", type=str, default=kwargs["voice-convert"])
|
||||
parser.add_argument("--language", type=str, default=kwargs["language"])
|
||||
parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
|
||||
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
|
||||
|
@ -275,6 +284,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
sampling_kwargs = dict(
|
||||
split_text_by=args.split_text_by,
|
||||
context_history=args.context_history,
|
||||
voice_convert=args.voice_convert,
|
||||
max_steps=args.max_steps,
|
||||
max_levels=args.max_levels,
|
||||
max_duration=args.max_duration,
|
||||
|
@ -391,6 +401,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
"""
|
||||
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
|
||||
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||
|
||||
while True:
|
||||
metrics = next(it)
|
||||
yield metrics
|
||||
|
@ -430,10 +441,13 @@ with ui:
|
|||
with gr.Tab("Text-to-Speech"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=8):
|
||||
layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
|
||||
with gr.Tab("Text"):
|
||||
layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
|
||||
with gr.Tab("Speech"):
|
||||
layout["inference_tts"]["inputs"]["voice-convert"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Guiding utternace.")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
|
||||
layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Reference audio for TTS")
|
||||
# layout["inference_tts"]["stop"] = gr.Button(value="Stop")
|
||||
layout["inference_tts"]["outputs"]["output"] = gr.Audio(label="Output")
|
||||
layout["inference_tts"]["buttons"]["inference"] = gr.Button(value="Inference")
|
||||
|
@ -496,6 +510,20 @@ with ui:
|
|||
outputs=[ x for x in layout["inference_tts"]["outputs"].values() if x is not None]
|
||||
)
|
||||
|
||||
# IC
|
||||
layout["inference_tts"]["inputs"]["voice-convert"].change(
|
||||
gradio_transcribe_input,
|
||||
[
|
||||
layout["inference_tts"]["inputs"]["voice-convert"],
|
||||
layout["inference_tts"]["inputs"]["text"],
|
||||
layout["inference_tts"]["inputs"]["split-text-by"],
|
||||
],
|
||||
[
|
||||
layout["inference_tts"]["inputs"]["text"],
|
||||
layout["inference_tts"]["inputs"]["split-text-by"],
|
||||
]
|
||||
)
|
||||
|
||||
with gr.Tab("Speech to Text"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=8):
|
||||
|
|
Loading…
Reference in New Issue
Block a user