really shoddy voice conversion implementation (it sort of works...)
This commit is contained in:
parent
8515038968
commit
c2e17e287b
|
@ -95,7 +95,7 @@ def main():
|
||||||
parser.add_argument("--out-path", type=Path, default=None)
|
parser.add_argument("--out-path", type=Path, default=None)
|
||||||
|
|
||||||
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
|
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
|
||||||
parser.add_argument("--max-steps", type=int, default=50)
|
parser.add_argument("--max-steps", type=int, default=30)
|
||||||
parser.add_argument("--max-levels", type=int, default=7)
|
parser.add_argument("--max-levels", type=int, default=7)
|
||||||
|
|
||||||
parser.add_argument("--ar-temperature", type=float, default=1.0)
|
parser.add_argument("--ar-temperature", type=float, default=1.0)
|
||||||
|
|
|
@ -13,6 +13,8 @@ from pathlib import Path
|
||||||
|
|
||||||
from .emb import g2p, qnt
|
from .emb import g2p, qnt
|
||||||
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
|
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
|
||||||
|
from .emb.transcribe import transcribe
|
||||||
|
|
||||||
from .utils import to_device, set_seed, clamp, wrapper as ml
|
from .utils import to_device, set_seed, clamp, wrapper as ml
|
||||||
|
|
||||||
from .config import cfg, Config
|
from .config import cfg, Config
|
||||||
|
@ -118,7 +120,7 @@ class TTS():
|
||||||
return torch.tensor([ id ])
|
return torch.tensor([ id ])
|
||||||
|
|
||||||
# to-do: trim before quantizing, instead of after
|
# to-do: trim before quantizing, instead of after
|
||||||
def encode_audio( self, paths, trim_length=5.0 ):
|
def encode_audio( self, paths, trim_length=0.0 ):
|
||||||
# already a tensor, return it
|
# already a tensor, return it
|
||||||
if isinstance( paths, Tensor ):
|
if isinstance( paths, Tensor ):
|
||||||
return paths
|
return paths
|
||||||
|
@ -358,6 +360,12 @@ class TTS():
|
||||||
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
||||||
amp = sampling_kwargs.pop("amp", self.amp)
|
amp = sampling_kwargs.pop("amp", self.amp)
|
||||||
|
|
||||||
|
voice_convert = sampling_kwargs.pop("voice_convert", None)
|
||||||
|
|
||||||
|
# no text was given: transcribe it from the audio being voice-converted
|
||||||
|
if voice_convert is not None and not text:
|
||||||
|
text = transcribe( voice_convert, model_name="openai/whisper-base", align=False )["text"]
|
||||||
|
|
||||||
lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
|
lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
|
||||||
|
|
||||||
wavs = []
|
wavs = []
|
||||||
|
@ -430,6 +438,7 @@ class TTS():
|
||||||
if auto_text_lang:
|
if auto_text_lang:
|
||||||
text_language = deduced_language
|
text_language = deduced_language
|
||||||
|
|
||||||
|
vc_utterance = self.encode_audio( voice_convert, trim_length=0 ) if voice_convert else None
|
||||||
prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
|
prom = self.encode_audio( references, trim_length=input_prompt_length ) if references else None
|
||||||
phns = self.encode_text( line, language=text_language )
|
phns = self.encode_text( line, language=text_language )
|
||||||
lang = self.encode_lang( language )
|
lang = self.encode_lang( language )
|
||||||
|
@ -457,6 +466,8 @@ class TTS():
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if prefix_context is not None:
|
if prefix_context is not None:
|
||||||
kwargs["prefix_context"] = prefix_context
|
kwargs["prefix_context"] = prefix_context
|
||||||
|
if vc_utterance is not None:
|
||||||
|
kwargs["vc_list"] = [ vc_utterance ]
|
||||||
|
|
||||||
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
|
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
|
||||||
**(sampling_kwargs | kwargs),
|
**(sampling_kwargs | kwargs),
|
||||||
|
|
|
@ -259,9 +259,12 @@ class AR_NAR(Base):
|
||||||
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
||||||
|
|
||||||
# to specify the initial mask used
|
# to specify the initial mask used
|
||||||
mask_list = sampling_kwargs.pop("mask_list", None)
|
vc_list = sampling_kwargs.pop("vc_list", None)
|
||||||
if mask_list is not None:
|
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
|
||||||
len_list = [ x.shape[0] for x in mask_list ]
|
vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)
|
||||||
|
if vc_list is not None:
|
||||||
|
vc_list = [ x if x.dim() == 1 else x[:, 0] for x in vc_list ]
|
||||||
|
len_list = [ x.shape[0] for x in vc_list ]
|
||||||
|
|
||||||
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
|
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
|
||||||
|
|
||||||
|
@ -305,16 +308,24 @@ class AR_NAR(Base):
|
||||||
remask_p = 1.0 / (max_steps * 2) if remasking else 0
|
remask_p = 1.0 / (max_steps * 2) if remasking else 0
|
||||||
# pick the worst scoring tokens to mask off
|
# pick the worst scoring tokens to mask off
|
||||||
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
||||||
if mask_list is None:
|
# normal masking
|
||||||
|
if vc_list is None or timestep >= vc_threshold:
|
||||||
# mask off inputs
|
# mask off inputs
|
||||||
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
|
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||||
# boolean mask
|
# boolean mask
|
||||||
is_masked = [ resps == self.stop_token for resps in resps_list ]
|
is_masked = [ resps == self.stop_token for resps in resps_list ]
|
||||||
else:
|
else:
|
||||||
# mask off inputs
|
# mask off a random portion of the target
|
||||||
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, mask_list ) ]
|
rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
|
||||||
|
half_mask_list = [ torch.where( rand_mask, self.stop_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
|
||||||
|
# always mask off the tail end (last 75 tokens), because leaving it unmasked causes issues
|
||||||
|
for i, mask in enumerate(half_mask_list):
|
||||||
|
half_mask_list[i][-75:] = self.stop_token
|
||||||
|
#
|
||||||
|
# mask off inputs per mask
|
||||||
|
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
|
||||||
# boolean mask
|
# boolean mask
|
||||||
is_masked = [ resps == mask for resps, mask in zip( resps_list, mask_list ) ]
|
is_masked = [ resps == mask for resps, mask in zip( resps_list, half_mask_list ) ]
|
||||||
|
|
||||||
# timestep inputs
|
# timestep inputs
|
||||||
time_list = [ timestep for _ in range(batch_size) ]
|
time_list = [ timestep for _ in range(batch_size) ]
|
||||||
|
@ -392,7 +403,7 @@ class AR_NAR(Base):
|
||||||
# get sampled tokens
|
# get sampled tokens
|
||||||
sampled_ids = filtered_sampled.ids
|
sampled_ids = filtered_sampled.ids
|
||||||
# keep unmasked tokens
|
# keep unmasked tokens
|
||||||
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
||||||
# get probability scores
|
# get probability scores
|
||||||
scores = [
|
scores = [
|
||||||
# conjugate to have worse scoring tokens picked for topk
|
# conjugate to have worse scoring tokens picked for topk
|
||||||
|
|
|
@ -42,6 +42,7 @@ if USING_SPACES:
|
||||||
from vall_e.emb.qnt import decode_to_wave
|
from vall_e.emb.qnt import decode_to_wave
|
||||||
from vall_e.data import get_lang_symmap, get_random_prompt
|
from vall_e.data import get_lang_symmap, get_random_prompt
|
||||||
from vall_e.models.arch import AVAILABLE_ATTENTIONS
|
from vall_e.models.arch import AVAILABLE_ATTENTIONS
|
||||||
|
from vall_e.emb.transcribe import transcribe
|
||||||
else:
|
else:
|
||||||
from .inference import TTS, cfg
|
from .inference import TTS, cfg
|
||||||
from .train import train
|
from .train import train
|
||||||
|
@ -50,6 +51,8 @@ else:
|
||||||
from .emb.qnt import decode_to_wave
|
from .emb.qnt import decode_to_wave
|
||||||
from .data import get_lang_symmap, get_random_prompt
|
from .data import get_lang_symmap, get_random_prompt
|
||||||
from .models.arch import AVAILABLE_ATTENTIONS
|
from .models.arch import AVAILABLE_ATTENTIONS
|
||||||
|
from .emb.transcribe import transcribe
|
||||||
|
|
||||||
|
|
||||||
is_windows = sys.platform.startswith("win")
|
is_windows = sys.platform.startswith("win")
|
||||||
|
|
||||||
|
@ -144,6 +147,11 @@ def load_sample( speaker ):
|
||||||
|
|
||||||
return data, (sr, wav)
|
return data, (sr, wav)
|
||||||
|
|
||||||
|
def gradio_transcribe_input( audio, text, split_by ):
|
||||||
|
if not audio:
|
||||||
|
return ( text, split_by )
|
||||||
|
return ( transcribe( audio, model_name="openai/whisper-base", align=False )["text"], "lines" )
|
||||||
|
|
||||||
def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None):
|
def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None):
|
||||||
global tts
|
global tts
|
||||||
|
|
||||||
|
@ -203,6 +211,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
parser.add_argument("--task", type=str, default="tts")
|
parser.add_argument("--task", type=str, default="tts")
|
||||||
parser.add_argument("--modality", type=str, default=kwargs["modality"])
|
parser.add_argument("--modality", type=str, default=kwargs["modality"])
|
||||||
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||||
|
parser.add_argument("--voice-convert", type=str, default=kwargs["voice-convert"])
|
||||||
parser.add_argument("--language", type=str, default=kwargs["language"])
|
parser.add_argument("--language", type=str, default=kwargs["language"])
|
||||||
parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
|
parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
|
||||||
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
|
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
|
||||||
|
@ -275,6 +284,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
sampling_kwargs = dict(
|
sampling_kwargs = dict(
|
||||||
split_text_by=args.split_text_by,
|
split_text_by=args.split_text_by,
|
||||||
context_history=args.context_history,
|
context_history=args.context_history,
|
||||||
|
voice_convert=args.voice_convert,
|
||||||
max_steps=args.max_steps,
|
max_steps=args.max_steps,
|
||||||
max_levels=args.max_levels,
|
max_levels=args.max_levels,
|
||||||
max_duration=args.max_duration,
|
max_duration=args.max_duration,
|
||||||
|
@ -391,6 +401,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
"""
|
"""
|
||||||
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
|
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
|
||||||
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
metrics = next(it)
|
metrics = next(it)
|
||||||
yield metrics
|
yield metrics
|
||||||
|
@ -430,7 +441,10 @@ with ui:
|
||||||
with gr.Tab("Text-to-Speech"):
|
with gr.Tab("Text-to-Speech"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=8):
|
with gr.Column(scale=8):
|
||||||
|
with gr.Tab("Text"):
|
||||||
layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
|
layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
|
||||||
|
with gr.Tab("Speech"):
|
||||||
|
layout["inference_tts"]["inputs"]["voice-convert"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Guiding utterance.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Reference audio for TTS")
|
layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") # , info="Reference audio for TTS")
|
||||||
|
@ -496,6 +510,20 @@ with ui:
|
||||||
outputs=[ x for x in layout["inference_tts"]["outputs"].values() if x is not None]
|
outputs=[ x for x in layout["inference_tts"]["outputs"].values() if x is not None]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# IC
|
||||||
|
layout["inference_tts"]["inputs"]["voice-convert"].change(
|
||||||
|
gradio_transcribe_input,
|
||||||
|
[
|
||||||
|
layout["inference_tts"]["inputs"]["voice-convert"],
|
||||||
|
layout["inference_tts"]["inputs"]["text"],
|
||||||
|
layout["inference_tts"]["inputs"]["split-text-by"],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
layout["inference_tts"]["inputs"]["text"],
|
||||||
|
layout["inference_tts"]["inputs"]["split-text-by"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Tab("Speech to Text"):
|
with gr.Tab("Speech to Text"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=8):
|
with gr.Column(scale=8):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user