fixed subsequent utterances through the web UI breaking due to the latents being left on the CPU
parent 96b74f38ef
commit f4fcc35aa8
@@ -13,8 +13,8 @@ def main():
     parser.add_argument("--out-path", type=Path, default=None)
     parser.add_argument("--max-ar-steps", type=int, default=500)
     parser.add_argument("--max-diffusion-steps", type=int, default=80)
-    parser.add_argument("--ar-temp", type=float, default=1.0)
-    parser.add_argument("--diffusion-temp", type=float, default=0.01)
+    parser.add_argument("--ar-temp", type=float, default=0.8)
+    parser.add_argument("--diffusion-temp", type=float, default=1.0)
     parser.add_argument("--top-p", type=float, default=1.0)
     parser.add_argument("--top-k", type=int, default=16)
     parser.add_argument("--repetition-penalty", type=float, default=1.0)
@@ -103,7 +103,7 @@ class TTS():
         max_diffusion_steps=80,
         #max_ar_context=-1,
         #input_prompt_length=0.0,
-        ar_temp=1.0,
+        ar_temp=0.8,
         diffusion_temp=1.0,
         #min_ar_temp=0.95,
         #min_diffusion_temp=0.5,
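These two hunks only retune sampling defaults: the AR temperature drops from 1.0 to 0.8 (both in the CLI flag and in the TTS class signature), and the CLI's default diffusion temperature rises from 0.01 to 1.0. As a quick, repo-independent illustration of what the AR temperature controls, here is a minimal sketch in plain PyTorch; values below 1.0 sharpen the next-token distribution before sampling:

```python
import torch

# Minimal sketch, not the repo's sampler: temperature divides the logits
# before softmax, so temp < 1.0 concentrates probability on the top tokens.
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # unnormalized next-token scores

for temperature in (1.0, 0.8):
    probs = torch.softmax(logits / temperature, dim=-1)
    token = torch.multinomial(probs, num_samples=1).item()
    print(f"temp={temperature}: probs={[round(p, 3) for p in probs.tolist()]}, sampled token {token}")
```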
@@ -131,8 +131,6 @@ class TTS():
         clvp = None
         vocoder = None
         diffuser = get_diffuser(steps=max_diffusion_steps, cond_free=cond_free)
-
-        autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"]

         for name, engine in self.engines.items():
             if "autoregressive" in name:
@@ -152,6 +150,10 @@ class TTS():
                 clvp = load_model("clvp", device=cfg.device)
             if vocoder is None:
                 vocoder = load_model("vocoder", device=cfg.device)

+            autoregressive = autoregressive.to(cfg.device)
+            diffusion = diffusion.to(cfg.device)
+            autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"]
+
             # shove everything to cpu
             if cfg.inference.auto_unload:
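This hunk is the core of the fix named in the commit message: the models are explicitly moved to `cfg.device` first, and only then are the reference latents encoded, so on every utterance (not just the first one, before `auto_unload` has shoved things back to the CPU) the latents live on the same device as the weights that consume them. A minimal sketch of the failure mode and the ordering fix in plain PyTorch; `model` and `encode_refs` are hypothetical stand-ins, not this repo's API:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-ins: `model` for the network, `encode_refs` for the
# conditioning-latent encoder.
model = nn.Linear(16, 16)  # a freshly loaded (or auto-unloaded) model sits on the CPU


def encode_refs(refs: torch.Tensor) -> torch.Tensor:
    # The latents inherit whatever device the model is on *right now*.
    return model(refs).detach()


# Broken ordering: latents computed before the move stay on the CPU.
stale_latents = encode_refs(torch.randn(1, 16))
model.to(device)
try:
    model(stale_latents)  # device mismatch on a CUDA machine
except RuntimeError as err:
    print("subsequent utterance fails:", err)

# Fixed ordering (what this commit does): move the model first, encode after.
fresh_latents = encode_refs(torch.randn(1, 16, device=device))
assert fresh_latents.device == next(model.parameters()).device
```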
@@ -164,6 +166,8 @@ class TTS():
         # other vars
         calm_token = 832

+        candidates = 1
+
         for line in lines:
             if out_path is None:
                 output_dir = Path("./data/results/")
@@ -185,7 +189,7 @@ class TTS():
                        do_sample=True,
                        top_p=top_p,
                        temperature=ar_temp,
-                        num_return_sequences=1,
+                        num_return_sequences=candidates,
                        length_penalty=length_penalty,
                        repetition_penalty=repetition_penalty,
                        max_generate_length=max_ar_steps,
@@ -213,12 +217,13 @@ class TTS():

                wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)

+                # to-do: actually move this after the CLVP to get the best latents instead
                latents = autoregressive.forward(
-                    autoregressive_latents,
-                    text_tokens,
-                    text_lengths,
+                    autoregressive_latents if candidates <= 1 else autoregressive_latents.repeat(candidates, 1),
+                    text_tokens if candidates <= 1 else text_tokens.repeat(candidates, 1),
+                    text_lengths if candidates <= 1 else text_lengths.repeat(candidates, 1),
                    codes,
-                    wav_lengths,
+                    wav_lengths if candidates <= 1 else wav_lengths.repeat(candidates, 1),
                    return_latent=True,
                    clip_inputs=False
                )
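The conditional `.repeat(candidates, 1)` calls above expand the batch-1 conditioning latents, text tokens, and lengths so a single forward pass covers every sampled candidate. A small sketch of that repeat-along-the-batch-dimension pattern in plain PyTorch (shapes and names are illustrative, not taken from the repo):

```python
import torch

candidates = 4
text_tokens = torch.randint(0, 256, (1, 12))  # (batch=1, sequence)
latents = torch.randn(1, 1024)                # (batch=1, dim) conditioning latents

# Repeat batch-1 inputs so each of the N candidate code sequences
# sees the same text and conditioning in one batched forward pass.
if candidates > 1:
    text_tokens = text_tokens.repeat(candidates, 1)
    latents = latents.repeat(candidates, 1)

print(text_tokens.shape, latents.shape)  # torch.Size([4, 12]) torch.Size([4, 1024])
```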
@@ -233,6 +238,13 @@ class TTS():
                            latents = latents[:, :k]
                        break

+                # clvp pass
+                if candidates > 1:
+                    with ml.auto_unload(clvp, enabled=cfg.inference.auto_unload):
+                        scores = clvp(text_tokens.repeat(codes.shape[0], 1), codes, return_loss=False)
+                        indices = torch.topk(scores, k=candidates).indices
+                        codes = codes[indices]
+
                # diffusion pass
                with ml.auto_unload(diffusion, enabled=cfg.inference.auto_unload):
                    output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
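The new `# clvp pass` block scores each candidate code sequence against the repeated text tokens and reorders them with `torch.topk` (with `k=candidates` it keeps them all, best first, so the diffusion pass sees the highest-scoring candidate at index 0). A self-contained sketch of that re-ranking step, with a random scorer standing in for the real CLVP model:

```python
import torch

candidates = 4
codes = torch.randint(0, 8192, (candidates, 250))  # one row of mel codes per candidate


def fake_clvp_score(code_batch: torch.Tensor) -> torch.Tensor:
    # Stand-in for CLVP: one text/audio similarity score per candidate.
    return torch.randn(code_batch.shape[0])


scores = fake_clvp_score(codes)
# k=candidates keeps every candidate but sorts them by descending score,
# mirroring the diff; a smaller k would prune down to the best few instead.
indices = torch.topk(scores, k=candidates).indices
codes = codes[indices]
print(codes.shape)  # torch.Size([4, 250])
```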
@@ -258,6 +270,6 @@ class TTS():
                if out_path is not None:
                    torchaudio.save( out_path, wav.cpu(), sr )
                wavs.append(wav)
-
+
        return (torch.concat(wavs, dim=-1), sr)
-
+
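The unchanged tail above concatenates each line's waveform along the time axis before returning. For reference, a tiny sketch of that concatenation (shapes are illustrative):

```python
import torch

sr = 24_000
wavs = [torch.randn(1, sr * 2), torch.randn(1, sr * 3)]  # two generated lines, (channels, samples)

full = torch.concat(wavs, dim=-1)  # join along the time axis
print(full.shape)  # torch.Size([1, 120000])
```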
|
|
@@ -219,8 +219,8 @@ with ui:
                layout["inference"]["inputs"]["diffusion-sampler"] = gr.Radio( ["P", "DDIM"], value="DDIM", label="Diffusion Samplers", type="value", info="Sampler to use during the diffusion pass." )
                layout["inference"]["inputs"]["cond-free"] = gr.Checkbox(label="Cond. Free", value=True, info="Condition Free diffusion")
            with gr.Row():
-                layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
-                layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.")
+                layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.8, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
+                layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.")
            """
            with gr.Row():
                layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")