From 0b1a71430c0e9e0f56b4ae5c6b82220fecb63048 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 19 Jun 2024 21:43:29 -0500 Subject: [PATCH] added BigVGAN and HiFiGAN (from https://git.ecker.tech/jarod/tortoise-tts), vocoder selectable in webUI --- data/config.yaml | 4 +- tortoise_tts/__main__.py | 3 + tortoise_tts/config.py | 2 + tortoise_tts/inference.py | 108 +- tortoise_tts/models/__init__.py | 52 +- tortoise_tts/models/arch_utils.py | 2 + tortoise_tts/models/bigvgan.py | 772 ++++++++++++ tortoise_tts/models/classifier.py | 2 + tortoise_tts/models/clvp.py | 2 + tortoise_tts/models/diffusion.py | 2 + tortoise_tts/models/hifigan.py | 305 +++++ tortoise_tts/models/lora.py | 1 + .../models/random_latent_generator.py | 3 + tortoise_tts/models/stream_generator.py | 1064 +++++++++++++++++ tortoise_tts/models/transformer.py | 2 + tortoise_tts/models/unified_voice.py | 49 +- tortoise_tts/models/vocoder.py | 2 + tortoise_tts/models/xtransformers.py | 2 + tortoise_tts/webui.py | 5 +- 19 files changed, 2364 insertions(+), 18 deletions(-) create mode 100644 tortoise_tts/models/bigvgan.py create mode 100644 tortoise_tts/models/hifigan.py create mode 100644 tortoise_tts/models/stream_generator.py diff --git a/data/config.yaml b/data/config.yaml index 5aa94a7..154f310 100644 --- a/data/config.yaml +++ b/data/config.yaml @@ -1,3 +1,5 @@ +vocoder: "hifigan" + models: - name: "autoregressive" training: True @@ -71,7 +73,7 @@ trainer: backend: deepspeed deepspeed: - inferencing: True + inferencing: False zero_optimization_level: 0 use_compression_training: False diff --git a/tortoise_tts/__main__.py b/tortoise_tts/__main__.py index 56441fa..c60dda6 100755 --- a/tortoise_tts/__main__.py +++ b/tortoise_tts/__main__.py @@ -24,6 +24,7 @@ def main(): parser.add_argument("--diffusion-sampler", type=str, default="ddim") parser.add_argument("--cond-free", action="store_true") + parser.add_argument("--vocoder", type=str, default="bigvgan") parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--device", type=str, default=None) @@ -62,6 +63,8 @@ def main(): diffusion_sampler=args.diffusion_sampler, cond_free=args.cond_free, + + vocoder_type=args.vocoder, ) """ language=args.language, diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py index 3cb4aa4..7cbaa84 100755 --- a/tortoise_tts/config.py +++ b/tortoise_tts/config.py @@ -526,6 +526,8 @@ class Config(BaseConfig): sample_rate: int = 24_000 audio_backend: str = "mel" + vocoder: str = "bigvgan" # "vocoder" | "bigvgan" | "hifigan" + @property def model(self): for i, model in enumerate(self.models): diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py index e19c431..b739321 100755 --- a/tortoise_tts/inference.py +++ b/tortoise_tts/inference.py @@ -5,6 +5,7 @@ import soundfile from torch import Tensor from einops import rearrange from pathlib import Path +from tqdm import tqdm from .emb.mel import encode_from_files as encode_mel, trim, trim_random from .utils import to_device @@ -96,6 +97,21 @@ class TTS(): # merge inputs return encode_mel( paths, device=self.device ) + # taken from here https://github.com/coqui-ai/TTS/blob/d21f15cc850788f9cdf93dac0321395138665287/TTS/tts/models/xtts.py#L666 + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + @torch.inference_mode() def inference( self, @@ -122,7 +138,9 @@ class TTS(): diffusion_sampler="ddim", cond_free=True, - out_path=None + vocoder_type="bigvgan", + + out_path=None, ): lines = text.split("\n") @@ -142,7 +160,7 @@ class TTS(): diffusion = engine.module elif "clvp" in name: clvp = engine.module - elif "vocoder" in name: + elif vocoder_type in name: vocoder = engine.module if autoregressive is None: @@ -152,7 +170,7 @@ class TTS(): if clvp is None: clvp = load_model("clvp", device=cfg.device) if vocoder is None: - vocoder = load_model("vocoder", device=cfg.device) + vocoder = load_model(vocoder_type, device=cfg.device) autoregressive = autoregressive.to(cfg.device) diffusion = diffusion.to(cfg.device) @@ -183,6 +201,88 @@ class TTS(): text_tokens = pad_sequence([ text ], batch_first = True) text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32) + # streaming interface spits out the final hidden state, which HiFiGAN seems to be trained against + if vocoder_type == "hifigan": + waves = [] + all_latents = [] + all_codes = [] + + wav_gen_prev = None + wav_overlap = None + is_end = False + first_buffer = 60 + + stream_chunk_size = 40 + overlap_wav_len = 1024 + + with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): + with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload): + with ml.auto_unload(vocoder, enabled=cfg.inference.auto_unload): + inputs = autoregressive.compute_embeddings( autoregressive_latents, text_tokens ) + + gpt_generator = autoregressive.get_generator( + inputs=inputs, + top_k=top_k, + top_p=top_p, + temperature=ar_temp, + do_sample=True, + num_beams=max(1, beam_width), + num_return_sequences=1, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + output_hidden_states=True, + ) + + bar = tqdm( unit="it", total=500 ) + while not is_end: + try: + codes, latent = next(gpt_generator) + all_latents += [latent] + all_codes += [codes] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(all_codes) >= max(stream_chunk_size, first_buffer)): + first_buffer = 0 + all_codes = [] + bar.update( stream_chunk_size ) + + latents = torch.cat(all_latents, dim=0)[None, :].to(cfg.device) + wav_gen = vocoder.inference(latents, autoregressive_latents) + wav_gen = wav_gen.squeeze() + + wav_chunk = wav_gen[:-overlap_wav_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_wav_len) : -overlap_wav_len] + if wav_overlap is not None: + crossfade_wav = wav_chunk[:overlap_wav_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_wav_len).to(crossfade_wav.device) + wav_chunk[:overlap_wav_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_wav_len).to(wav_overlap.device) + wav_chunk[:overlap_wav_len] += crossfade_wav + + wav_overlap = wav_gen[-overlap_wav_len:] + wav_gen_prev = wav_gen + + + # yielding requires to do a bunch of pain to work around it turning into an async function + """ + yield wav_chunk + """ + + waves.append( wav_chunk.unsqueeze(0) ) + + bar.close() + + wav = torch.concat(waves, dim=-1) + + if out_path is not None: + torchaudio.save( out_path, wav.cpu(), sr ) + + wavs.append(wav) + + continue + with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload): # autoregressive pass @@ -190,9 +290,11 @@ class TTS(): autoregressive_latents, text_tokens, do_sample=True, + top_k=top_k, top_p=top_p, temperature=ar_temp, num_return_sequences=candidates, + num_beams=max(1,beam_width), length_penalty=length_penalty, repetition_penalty=repetition_penalty, max_generate_length=max_ar_steps, diff --git a/tortoise_tts/models/__init__.py b/tortoise_tts/models/__init__.py index d5da18d..3240afc 100755 --- a/tortoise_tts/models/__init__.py +++ b/tortoise_tts/models/__init__.py @@ -1,5 +1,4 @@ -# https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models -# All code under this folder is licensed as Apache License 2.0 per the original repo +# All other ccode in this folder are licensed per the attributions at the top from functools import cache @@ -8,6 +7,8 @@ from .arch_utils import TorchMelSpectrogram, TacotronSTFT from .unified_voice import UnifiedVoice from .diffusion import DiffusionTTS from .vocoder import UnivNetGenerator +from .bigvgan import BigVGAN +from .hifigan import HifiganGenerator from .clvp import CLVP from .dvae import DiscreteVAE from .random_latent_generator import RandomLatentConverter @@ -15,6 +16,8 @@ from .random_latent_generator import RandomLatentConverter import os import torch from pathlib import Path +import requests +from tqdm import tqdm DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models' DEFAULT_MODEL_URLS = { @@ -28,10 +31,18 @@ DEFAULT_MODEL_URLS = { 'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth', 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', 'mel_norms.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/data/mel_norms.pth', + + # BigVGAN + 'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth', + 'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth', + + 'bigvgan_base_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.json', + 'bigvgan_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.json', + + # HiFiGAN + 'hifigan.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth', } -import requests -from tqdm import tqdm # kludge, probably better to use HF's model downloader function # to-do: write to a temp file then copy so downloads can be interrupted @@ -75,6 +86,7 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ): @cache def load_model(name, device="cuda", **kwargs): load_path = None + config_path = None state_dict_key = None strict = True @@ -95,6 +107,31 @@ def load_model(name, device="cuda", **kwargs): elif "clvp" in name: model = CLVP(**kwargs) load_path = DEFAULT_MODEL_PATH / 'clvp2.pth' + elif "bigvgan" in name: + # download any JSONs (BigVGAN) + load_path = DEFAULT_MODEL_PATH / 'bigvgan_24khz_100band.pth' + config_path = load_path.with_suffix(".json") + if config_path.name in DEFAULT_MODEL_URLS: + if not config_path.exists(): + download_model( config_path ) + else: + config_path = None + + model = BigVGAN(config=config_path, **kwargs) + state_dict_key = 'generator' + elif "hifigan" in name: + model = HifiganGenerator( + in_channels=1024, + out_channels = 1, + resblock_type = "1", + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + resblock_kernel_sizes = [3, 7, 11], + upsample_kernel_sizes = [16, 16, 4, 4], + upsample_initial_channel = 512, + upsample_factors = [8, 8, 2, 2], + cond_channels=1024 + ) + load_path = DEFAULT_MODEL_PATH / 'hifigan.pth' elif "vocoder" in name: model = UnivNetGenerator(**kwargs) load_path = DEFAULT_MODEL_PATH / 'vocoder.pth' @@ -126,6 +163,11 @@ def load_model(name, device="cuda", **kwargs): model.eval() + try: + print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") + except Exception as e: + print(f"{name}: {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") + return model def unload_model(): @@ -138,8 +180,6 @@ def get_model(config, training=True): config.training = "autoregressive" in config.name model.config = config - print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") - return model def get_models(models, training=True): diff --git a/tortoise_tts/models/arch_utils.py b/tortoise_tts/models/arch_utils.py index 871f246..4993dd6 100644 --- a/tortoise_tts/models/arch_utils.py +++ b/tortoise_tts/models/arch_utils.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/arch_utils.py + import os import functools import math diff --git a/tortoise_tts/models/bigvgan.py b/tortoise_tts/models/bigvgan.py new file mode 100644 index 0000000..a3478e4 --- /dev/null +++ b/tortoise_tts/models/bigvgan.py @@ -0,0 +1,772 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math + +import json +import os +import torch.utils.data + +from torch import nn, sin, pow +from torch.nn import Conv1d, ConvTranspose1d, Conv2d, Parameter +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from librosa.filters import mel as librosa_mel_fn + + +# filter.py +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out + +# resample.py +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx + +# act.py +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x + +# activations.py +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + +# bigvgan.py +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +LRELU_SLOPE = 0.1 + +class AMPBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): + super(AMPBlock1, self).__init__() + self.h = h + + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'.") + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None): + super(AMPBlock2, self).__init__() + self.h = h + + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'.") + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + +class BigVGAN(nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__(self, config=None, data=None): + super(BigVGAN, self).__init__() + + """ + with open(os.path.join(os.path.dirname(__file__), 'config.json'), 'r') as f: + data = f.read() + """ + if config and data is None: + with open(config, 'r') as f: + data = f.read() + jsonConfig = json.loads(data) + elif data is not None: + if isinstance(data, str): + jsonConfig = json.loads(data) + else: + jsonConfig = data + else: + raise Exception("no config specified") + + + global h + h = AttrDict(jsonConfig) + + self.mel_channel = h.num_mels + self.noise_dim = h.n_fft + self.hop_length = h.hop_size + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + # pre conv + self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2 + + # transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append(nn.ModuleList([ + weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2)) + ])) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) + + # post conv + if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = Snake(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = SnakeBeta(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'.") + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self,x, c): + # pre conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + def inference(self, c, z=None): + # pad input mel with zeros to cut artifact + # see https://github.com/seungwonpark/melgan/issues/8 + zero = torch.full((c.shape[0], h.num_mels, 10), -11.5129).to(c.device) + mel = torch.cat((c, zero), dim=2) + + if z is None: + z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device) + + audio = self.forward(mel, z) + audio = audio[:, :, :-(self.hop_length * 10)] + audio = audio.clamp(min=-1, max=1) + return audio + + def eval(self, inference=False): + super(BigVGAN, self).eval() + # don't remove weight norm while validation in training loop + if inference: + self.remove_weight_norm() + + +class DiscriminatorP(nn.Module): + def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.d_mult = h.discriminator_channel_mult + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(nn.Module): + def __init__(self, h): + super(MultiPeriodDiscriminator, self).__init__() + self.mpd_reshapes = h.mpd_reshapes + print("mpd_reshapes: {}".format(self.mpd_reshapes)) + discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes] + self.discriminators = nn.ModuleList(discriminators) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__(self, cfg, resolution): + super().__init__() + + self.resolution = resolution + assert len(self.resolution) == 3, \ + "MRD layer requires list with len=3, got {}".format(self.resolution) + self.lrelu_slope = LRELU_SLOPE + + norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm + if hasattr(cfg, "mrd_use_spectral_norm"): + print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm)) + norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm + self.d_mult = cfg.discriminator_channel_mult + if hasattr(cfg, "mrd_channel_mult"): + print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult)) + self.d_mult = cfg.mrd_channel_mult + + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))), + ]) + self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))) + + def forward(self, x): + fmap = [] + + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x): + n_fft, hop_length, win_length = self.resolution + x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect') + x = x.squeeze(1) + x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True) + x = torch.view_as_real(x) # [B, F, TT, 2] + mag = torch.norm(x, p=2, dim=-1) # [B, F, TT] + + return mag + + +class MultiResolutionDiscriminator(nn.Module): + def __init__(self, cfg, debug=False): + super().__init__() + self.resolutions = cfg.resolutions + assert len(self.resolutions) == 3, \ + "MRD requires list of list with len=3, each element having a list with len=3. got {}". \ + format(self.resolutions) + self.discriminators = nn.ModuleList( + [DiscriminatorR(cfg, resolution) for resolution in self.resolutions] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + +def get_mel(x): + return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + + spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + spec = torch.nn.utils.spectral_normalize_torch(spec) + + return spec + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +if __name__ == '__main__': + model = BigVGAN() + + c = torch.randn(3, 100, 10) + z = torch.randn(3, 64, 10) + print(c.shape) + + y = model(c, z) + print(y.shape) + assert y.shape == torch.Size([3, 1, 2560]) + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(pytorch_total_params) \ No newline at end of file diff --git a/tortoise_tts/models/classifier.py b/tortoise_tts/models/classifier.py index cf00f3a..10331f2 100644 --- a/tortoise_tts/models/classifier.py +++ b/tortoise_tts/models/classifier.py @@ -1,3 +1,5 @@ +# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/classifier.py + import torch import torch.nn as nn diff --git a/tortoise_tts/models/clvp.py b/tortoise_tts/models/clvp.py index e6c43c4..b2ccda8 100644 --- a/tortoise_tts/models/clvp.py +++ b/tortoise_tts/models/clvp.py @@ -1,3 +1,5 @@ +# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/clvp.py + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/tortoise_tts/models/diffusion.py b/tortoise_tts/models/diffusion.py index 00c7c1e..c0137a3 100644 --- a/tortoise_tts/models/diffusion.py +++ b/tortoise_tts/models/diffusion.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/diffusion.py + import enum import math import random diff --git a/tortoise_tts/models/hifigan.py b/tortoise_tts/models/hifigan.py new file mode 100644 index 0000000..0883b3c --- /dev/null +++ b/tortoise_tts/models/hifigan.py @@ -0,0 +1,305 @@ +# Grabbed from https://git.ecker.tech/Jarod/tortoise-tts/src/branch/main +# Adapted from https://github.com/jik876/hifi-gan/blob/master/models.py + +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +LRELU_SLOPE = 0.1 + + +def get_padding(k, d): + return int((k * d - d) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_weight_norm(self.conv_pre) + + if not conv_post_weight_norm: + remove_weight_norm(self.conv_post) + + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c, g=None): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + # c = c.to(self.conv_pre.weight.device) + # c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + up_1 = torch.nn.functional.interpolate( + c.transpose(1,2), + scale_factor=[1024 / 256], + mode="linear", + ) + up_2 = torch.nn.functional.interpolate( + up_1, + scale_factor=[24000 / 22050], + mode="linear", + ) + g = g.unsqueeze(0) + return self.forward(up_2.to(self.device), g.transpose(1,2)) + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) \ No newline at end of file diff --git a/tortoise_tts/models/lora.py b/tortoise_tts/models/lora.py index 30ca4c3..cee0605 100644 --- a/tortoise_tts/models/lora.py +++ b/tortoise_tts/models/lora.py @@ -1,4 +1,5 @@ # Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py + from functools import partial import torch import torch.nn.functional as F diff --git a/tortoise_tts/models/random_latent_generator.py b/tortoise_tts/models/random_latent_generator.py index 4d6ff80..6fa4494 100644 --- a/tortoise_tts/models/random_latent_generator.py +++ b/tortoise_tts/models/random_latent_generator.py @@ -1,3 +1,6 @@ +# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/random_latent_generator.py + + import math import torch diff --git a/tortoise_tts/models/stream_generator.py b/tortoise_tts/models/stream_generator.py new file mode 100644 index 0000000..5f41407 --- /dev/null +++ b/tortoise_tts/models/stream_generator.py @@ -0,0 +1,1064 @@ +# Adapted from https://git.ecker.tech/Jarod/tortoise-tts/raw/commit/156bb5e7da98128071ab94fd4f3884e64566f3dc/tortoise/models/stream_generator.py +# Adapted from: https://github.com/LowinLi/transformers-stream-generator + +from transformers import ( + GenerationConfig, + GenerationMixin, + LogitsProcessorList, + StoppingCriteriaList, + DisjunctiveConstraint, + BeamSearchScorer, + PhrasalConstraint, + ConstrainedBeamSearchScorer, + PreTrainedModel, +) +import numpy as np +import random +import warnings +import inspect +from transformers.generation.utils import GenerateOutput, SampleOutput, logger +import torch +from typing import Callable, List, Optional, Union +from torch import nn +import torch.distributed as dist +import copy + + +def setup_seed(seed): + if seed == -1: + return + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +class StreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.do_stream = kwargs.pop("do_stream", False) + + +class NewGenerationMixin(GenerationMixin): + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[StreamGenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = False, + seed=0, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + setup_seed(seed) + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = StreamGenerationConfig.from_model_config( + self.config + ) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update( + **kwargs + ) # All unused kwargs must be model kwargs + # self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = eos_token_id + + # 3. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if ( + model_kwargs.get("attention_mask", None) is None + and requires_attention_mask + and accepts_attention_mask + ): + pad_token_tensor = torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device) if generation_config.pad_token_id is not None else None + eos_token_tensor = torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device) if generation_config.eos_token_id is not None else None + + model_kwargs[ + "attention_mask" + ] = self._prepare_attention_mask_for_generation( + inputs_tensor, + pad_token_tensor, + eos_token_tensor, + ) + + # decoder-only models should use left-padding for generation + if not self.config.is_encoder_decoder: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length + ) + elif ( + not has_default_max_length and generation_config.max_new_tokens is not None + ): + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + + if ( + generation_config.min_length is not None + and generation_config.min_length > generation_config.max_length + ): + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ( + "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + ) + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + if not hasattr(generation_config, "do_stream"): + generation_config.do_stream = kwargs.get("do_stream", False) + + # 7. determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None + or generation_config.force_words_ids is not None + ) and generation_config.do_stream is False + + is_contrastive_search_gen_mode = ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ) and generation_config.do_stream is False + + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) and generation_config.do_stream is False + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and generation_config.do_stream is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) and generation_config.do_stream is False + is_sample_gen_stream_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_stream is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) and generation_config.do_stream is False + is_beam_sample_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) and generation_config.do_stream is False + is_group_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) and generation_config.do_stream is False + + if generation_config.num_beam_groups > generation_config.num_beams: + raise ValueError( + "`num_beam_groups` has to be smaller or equal to `num_beams`" + ) + if is_group_beam_gen_mode and generation_config.do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + # 8. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + # 10. go into different generation modes + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." + ) + + # 11. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_contrastive_search_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config, input_ids.device) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self._sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + streamer=None, + generation_config=generation_config, + **model_kwargs, + ) + elif is_sample_gen_stream_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config, input_ids.device) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample_stream( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config, input_ids.device) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams + * generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if generation_config.num_beams % generation_config.num_beam_groups != 0: + raise ValueError( + "`num_beams` should be divisible by `num_beam_groups` for group beam search." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + has_default_typical_p = ( + kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + ) + if not has_default_typical_p: + raise ValueError( + "Decoder argument `typical_p` is not supported with beam groups." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + if generation_config.num_beams <= 1: + raise ValueError( + "`num_beams` needs to be greater than 1 for constrained generation." + ) + + if generation_config.do_sample: + raise ValueError( + "`do_sample` needs to be false for constrained generation." + ) + + if ( + generation_config.num_beam_groups is not None + and generation_config.num_beam_groups > 1 + ): + raise ValueError( + "`num_beam_groups` not supported yet for constrained generation." + ) + + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {generation_config.force_words_ids}." + ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + not isinstance(token_ids, list) for token_ids in word_ids + ): + typeerror() + if any( + any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids + ) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in word_ids + ): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + @torch.no_grad() + def sample_stream( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + For an overview of generation strategies and code examples, check the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> model.generation_config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long() + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + +def init_stream_support(): + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + +if __name__ == "__main__": + from transformers import PreTrainedModel + from transformers import AutoTokenizer, AutoModelForCausalLM + + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + model = AutoModelForCausalLM.from_pretrained( + "bigscience/bloom-560m", torch_dtype=torch.float16 + ) + + tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + model = model.to("cuda:0") + model = model.eval() + prompt_text = "hello? \n" + input_ids = tokenizer( + prompt_text, return_tensors="pt", add_special_tokens=False + ).input_ids + input_ids = input_ids.to("cuda:0") + + with torch.no_grad(): + result = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + ) + print(tokenizer.decode(result, skip_special_tokens=True)) + generator = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + do_stream=True, + ) + stream_result = "" + for x in generator: + chunk = tokenizer.decode(x, skip_special_tokens=True) + stream_result += chunk + print(stream_result) \ No newline at end of file diff --git a/tortoise_tts/models/transformer.py b/tortoise_tts/models/transformer.py index 353a949..b64d4db 100644 --- a/tortoise_tts/models/transformer.py +++ b/tortoise_tts/models/transformer.py @@ -1,3 +1,5 @@ +# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/transformer.py + from functools import partial import torch diff --git a/tortoise_tts/models/unified_voice.py b/tortoise_tts/models/unified_voice.py index c1d6096..f61f41e 100644 --- a/tortoise_tts/models/unified_voice.py +++ b/tortoise_tts/models/unified_voice.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/unified_voice.py + import functools import torch @@ -14,6 +16,8 @@ from transformers import LogitsWarper from transformers import GPT2Config, GPT2Model from tqdm import tqdm +from .stream_generator import NewGenerationMixin + AVAILABLE_ATTENTIONS = ["mem_efficient", "math"] try: @@ -83,12 +87,14 @@ class ResBlock(nn.Module): def forward(self, x): return F.relu(self.net(x) + x) -class GPT2InferenceModel(GPT2PreTrainedModel): +class GPT2InferenceModel(GPT2PreTrainedModel, NewGenerationMixin): def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True): - super().__init__(config) + super(NewGenerationMixin, self).__init__() + super(GPT2PreTrainedModel, self).__init__(config) self.transformer = gpt self.text_pos_embedding = text_pos_emb self.embeddings = embeddings + self.final_norm = norm self.lm_head = nn.Sequential(norm, linear) self.kv_cache = kv_cache @@ -129,14 +135,14 @@ class GPT2InferenceModel(GPT2PreTrainedModel): def store_mel_emb(self, mel_emb): self.cached_mel_emb = mel_emb - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) if not self.kv_cache: - past = None + past_key_values = None # only last token for inputs_ids if past is defined in kwargs - if past: + if past_key_values: input_ids = input_ids[:, -1].unsqueeze(-1) if token_type_ids is not None: token_type_ids = token_type_ids[:, -1].unsqueeze(-1) @@ -148,13 +154,13 @@ class GPT2InferenceModel(GPT2PreTrainedModel): # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - if past: + if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) else: position_ids = None return { "input_ids": input_ids, - "past_key_values": past, + "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, @@ -597,6 +603,24 @@ class UnifiedVoice(nn.Module): return loss_text.mean(), loss_mel.mean(), mel_logits + def compute_embeddings( self, cond_latents, text_inputs, kv_cache = True ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + conds = cond_latents.unsqueeze(1) + emb = torch.cat([conds, emb], dim=1) + + if not hasattr(self, 'inference_model'): + # TODO: Decouple gpt_config from this inference model. + self.post_init_gpt2_config(kv_cache = kv_cache) + + self.inference_model.store_mel_emb(emb) + + embs = torch.full( ( emb.shape[0], emb.shape[1] + 1 ), fill_value=1, dtype=torch.long, device=text_inputs.device ) + embs[:, -1] = self.start_mel_token + + return embs + def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs): @@ -635,6 +659,17 @@ class UnifiedVoice(nn.Module): self.inference_model.bar.close() return gen[:, trunc_index:] + def get_generator(self, inputs, max_length=500, **hf_generate_kwargs): + return self.inference_model.generate( + inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + max_length=max_length, + do_stream=True, + **hf_generate_kwargs, + ) + if __name__ == '__main__': gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4) diff --git a/tortoise_tts/models/vocoder.py b/tortoise_tts/models/vocoder.py index 39ec9e0..b825ca6 100644 --- a/tortoise_tts/models/vocoder.py +++ b/tortoise_tts/models/vocoder.py @@ -1,3 +1,5 @@ +# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/vocoder.py + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/tortoise_tts/models/xtransformers.py b/tortoise_tts/models/xtransformers.py index 4accb20..c75c9e0 100644 --- a/tortoise_tts/models/xtransformers.py +++ b/tortoise_tts/models/xtransformers.py @@ -1,3 +1,5 @@ +# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/xtransformers.py + import math from collections import namedtuple from functools import partial diff --git a/tortoise_tts/webui.py b/tortoise_tts/webui.py index 4969b40..2324089 100644 --- a/tortoise_tts/webui.py +++ b/tortoise_tts/webui.py @@ -95,6 +95,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"]) parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"]) parser.add_argument("--cond-free", type=str, default=kwargs["cond-free"]) + parser.add_argument("--vocoder", type=str, default=kwargs["vocoder"].lower()) """ parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"]) @@ -126,6 +127,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): beam_width=args.beam_width, diffusion_sampler=args.diffusion_sampler, + vocoder_type=args.vocoder, ) wav = wav.squeeze(0).cpu().numpy() @@ -210,7 +212,7 @@ with ui: with gr.Column(scale=1): layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS") # layout["inference"]["stop"] = gr.Button(value="Stop") - layout["inference"]["outputs"]["output"] = gr.Audio(label="Output") + layout["inference"]["outputs"]["output"] = gr.Audio(label="Output", streaming=True) layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference") with gr.Column(scale=7): with gr.Row(): @@ -221,6 +223,7 @@ with ui: with gr.Row(): layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.8, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)") layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.") + layout["inference"]["inputs"]["vocoder"] = gr.Radio( ["Vocoder", "BigVGAN", "HiFiGAN"], value="BigVGAN", label="Vocoder", type="value", info="Vocoder to use for generating the final waveform (HiFiGAN skips diffusion)." ) """ with gr.Row(): layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")