diff --git a/data/config.yaml b/data/config.yaml
index 5aa94a7..154f310 100644
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,3 +1,5 @@
+vocoder: "hifigan"
+
models:
- name: "autoregressive"
training: True
@@ -71,7 +73,7 @@ trainer:
backend: deepspeed
deepspeed:
- inferencing: True
+ inferencing: False
zero_optimization_level: 0
use_compression_training: False
diff --git a/tortoise_tts/__main__.py b/tortoise_tts/__main__.py
index 56441fa..c60dda6 100755
--- a/tortoise_tts/__main__.py
+++ b/tortoise_tts/__main__.py
@@ -24,6 +24,7 @@ def main():
parser.add_argument("--diffusion-sampler", type=str, default="ddim")
parser.add_argument("--cond-free", action="store_true")
+ parser.add_argument("--vocoder", type=str, default="bigvgan")
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default=None)
@@ -62,6 +63,8 @@ def main():
diffusion_sampler=args.diffusion_sampler,
cond_free=args.cond_free,
+
+ vocoder_type=args.vocoder,
)
"""
language=args.language,
diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py
index 3cb4aa4..7cbaa84 100755
--- a/tortoise_tts/config.py
+++ b/tortoise_tts/config.py
@@ -526,6 +526,8 @@ class Config(BaseConfig):
sample_rate: int = 24_000
audio_backend: str = "mel"
+ vocoder: str = "bigvgan" # "vocoder" | "bigvgan" | "hifigan"
+
@property
def model(self):
for i, model in enumerate(self.models):
diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py
index e19c431..b739321 100755
--- a/tortoise_tts/inference.py
+++ b/tortoise_tts/inference.py
@@ -5,6 +5,7 @@ import soundfile
from torch import Tensor
from einops import rearrange
from pathlib import Path
+from tqdm import tqdm
from .emb.mel import encode_from_files as encode_mel, trim, trim_random
from .utils import to_device
@@ -96,6 +97,21 @@ class TTS():
# merge inputs
return encode_mel( paths, device=self.device )
+ # taken from here https://github.com/coqui-ai/TTS/blob/d21f15cc850788f9cdf93dac0321395138665287/TTS/tts/models/xtts.py#L666
+ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
+ """Handle chunk formatting in streaming mode"""
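+		# trim the overlap tail off the new audio, crossfade the previous tail into the
+		# head of this chunk, and stash the new tail for the next call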
+ wav_chunk = wav_gen[:-overlap_len]
+ if wav_gen_prev is not None:
+ wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
+ if wav_overlap is not None:
+ crossfade_wav = wav_chunk[:overlap_len]
+ crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
+ wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
+ wav_chunk[:overlap_len] += crossfade_wav
+ wav_overlap = wav_gen[-overlap_len:]
+ wav_gen_prev = wav_gen
+ return wav_chunk, wav_gen_prev, wav_overlap
+
@torch.inference_mode()
def inference(
self,
@@ -122,7 +138,9 @@ class TTS():
diffusion_sampler="ddim",
cond_free=True,
- out_path=None
+ vocoder_type="bigvgan",
+
+ out_path=None,
):
lines = text.split("\n")
@@ -142,7 +160,7 @@ class TTS():
diffusion = engine.module
elif "clvp" in name:
clvp = engine.module
- elif "vocoder" in name:
+ elif vocoder_type in name:
vocoder = engine.module
if autoregressive is None:
@@ -152,7 +170,7 @@ class TTS():
if clvp is None:
clvp = load_model("clvp", device=cfg.device)
if vocoder is None:
- vocoder = load_model("vocoder", device=cfg.device)
+ vocoder = load_model(vocoder_type, device=cfg.device)
autoregressive = autoregressive.to(cfg.device)
diffusion = diffusion.to(cfg.device)
@@ -183,6 +201,88 @@ class TTS():
text_tokens = pad_sequence([ text ], batch_first = True)
text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32)
+			# the streaming interface emits the final hidden states, which HiFiGAN appears to have been trained against
+ if vocoder_type == "hifigan":
+ waves = []
+ all_latents = []
+ all_codes = []
+
+ wav_gen_prev = None
+ wav_overlap = None
+ is_end = False
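+				# buffer `first_buffer` codes before decoding the first chunk, then decode every
+				# `stream_chunk_size` codes; `overlap_wav_len` samples are crossfaded between chunks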
+ first_buffer = 60
+
+ stream_chunk_size = 40
+ overlap_wav_len = 1024
+
+ with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+ with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload):
+ with ml.auto_unload(vocoder, enabled=cfg.inference.auto_unload):
+ inputs = autoregressive.compute_embeddings( autoregressive_latents, text_tokens )
+
+ gpt_generator = autoregressive.get_generator(
+ inputs=inputs,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=ar_temp,
+ do_sample=True,
+ num_beams=max(1, beam_width),
+ num_return_sequences=1,
+ length_penalty=length_penalty,
+ repetition_penalty=repetition_penalty,
+ output_attentions=False,
+ output_hidden_states=True,
+ )
+
+ bar = tqdm( unit="it", total=500 )
+ while not is_end:
+ try:
+ codes, latent = next(gpt_generator)
+ all_latents += [latent]
+ all_codes += [codes]
+ except StopIteration:
+ is_end = True
+
+ if is_end or (stream_chunk_size > 0 and len(all_codes) >= max(stream_chunk_size, first_buffer)):
+ first_buffer = 0
+ all_codes = []
+ bar.update( stream_chunk_size )
+
+ latents = torch.cat(all_latents, dim=0)[None, :].to(cfg.device)
+ wav_gen = vocoder.inference(latents, autoregressive_latents)
+ wav_gen = wav_gen.squeeze()
+
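+							# same crossfade logic as handle_chunks() above, inlined into the streaming loop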
+ wav_chunk = wav_gen[:-overlap_wav_len]
+ if wav_gen_prev is not None:
+ wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_wav_len) : -overlap_wav_len]
+ if wav_overlap is not None:
+ crossfade_wav = wav_chunk[:overlap_wav_len]
+ crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_wav_len).to(crossfade_wav.device)
+ wav_chunk[:overlap_wav_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_wav_len).to(wav_overlap.device)
+ wav_chunk[:overlap_wav_len] += crossfade_wav
+
+ wav_overlap = wav_gen[-overlap_wav_len:]
+ wav_gen_prev = wav_gen
+
+
+							# yielding here would turn this method into a generator, which the surrounding call path doesn't support yet
+ """
+ yield wav_chunk
+ """
+
+ waves.append( wav_chunk.unsqueeze(0) )
+
+ bar.close()
+
+ wav = torch.concat(waves, dim=-1)
+
+ if out_path is not None:
+ torchaudio.save( out_path, wav.cpu(), sr )
+
+ wavs.append(wav)
+
+ continue
+
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload):
# autoregressive pass
@@ -190,9 +290,11 @@ class TTS():
autoregressive_latents,
text_tokens,
do_sample=True,
+ top_k=top_k,
top_p=top_p,
temperature=ar_temp,
num_return_sequences=candidates,
+ num_beams=max(1,beam_width),
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_generate_length=max_ar_steps,
diff --git a/tortoise_tts/models/__init__.py b/tortoise_tts/models/__init__.py
index d5da18d..3240afc 100755
--- a/tortoise_tts/models/__init__.py
+++ b/tortoise_tts/models/__init__.py
@@ -1,5 +1,4 @@
-# https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models
-# All code under this folder is licensed as Apache License 2.0 per the original repo
+# All other code in this folder is licensed per the attributions at the top of each file
from functools import cache
@@ -8,6 +7,8 @@ from .arch_utils import TorchMelSpectrogram, TacotronSTFT
from .unified_voice import UnifiedVoice
from .diffusion import DiffusionTTS
from .vocoder import UnivNetGenerator
+from .bigvgan import BigVGAN
+from .hifigan import HifiganGenerator
from .clvp import CLVP
from .dvae import DiscreteVAE
from .random_latent_generator import RandomLatentConverter
@@ -15,6 +16,8 @@ from .random_latent_generator import RandomLatentConverter
import os
import torch
from pathlib import Path
+import requests
+from tqdm import tqdm
DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models'
DEFAULT_MODEL_URLS = {
@@ -28,10 +31,18 @@ DEFAULT_MODEL_URLS = {
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
'mel_norms.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/data/mel_norms.pth',
+
+ # BigVGAN
+ 'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth',
+ 'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth',
+
+ 'bigvgan_base_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.json',
+ 'bigvgan_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.json',
+
+ # HiFiGAN
+ 'hifigan.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
}
-import requests
-from tqdm import tqdm
# kludge, probably better to use HF's model downloader function
# to-do: write to a temp file then copy so downloads can be interrupted
@@ -75,6 +86,7 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
@cache
def load_model(name, device="cuda", **kwargs):
load_path = None
+ config_path = None
state_dict_key = None
strict = True
@@ -95,6 +107,31 @@ def load_model(name, device="cuda", **kwargs):
elif "clvp" in name:
model = CLVP(**kwargs)
load_path = DEFAULT_MODEL_PATH / 'clvp2.pth'
+ elif "bigvgan" in name:
+		# download the matching config JSON if one is available (BigVGAN)
+ load_path = DEFAULT_MODEL_PATH / 'bigvgan_24khz_100band.pth'
+ config_path = load_path.with_suffix(".json")
+ if config_path.name in DEFAULT_MODEL_URLS:
+ if not config_path.exists():
+ download_model( config_path )
+ else:
+ config_path = None
+
+ model = BigVGAN(config=config_path, **kwargs)
+ state_dict_key = 'generator'
+ elif "hifigan" in name:
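+		# hyperparameters are fixed to match the pretrained hifidecoder.pth checkpoint listed above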
+ model = HifiganGenerator(
+ in_channels=1024,
+ out_channels = 1,
+ resblock_type = "1",
+ resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ resblock_kernel_sizes = [3, 7, 11],
+ upsample_kernel_sizes = [16, 16, 4, 4],
+ upsample_initial_channel = 512,
+ upsample_factors = [8, 8, 2, 2],
+ cond_channels=1024
+ )
+ load_path = DEFAULT_MODEL_PATH / 'hifigan.pth'
elif "vocoder" in name:
model = UnivNetGenerator(**kwargs)
load_path = DEFAULT_MODEL_PATH / 'vocoder.pth'
@@ -126,6 +163,11 @@ def load_model(name, device="cuda", **kwargs):
model.eval()
+	try:
+		print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
+	except Exception:
+		# model exposes no parameters to read a dtype from
+		print(f"{name}: {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
+
return model
def unload_model():
@@ -138,8 +180,6 @@ def get_model(config, training=True):
config.training = "autoregressive" in config.name
model.config = config
- print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
-
return model
def get_models(models, training=True):
diff --git a/tortoise_tts/models/arch_utils.py b/tortoise_tts/models/arch_utils.py
index 871f246..4993dd6 100644
--- a/tortoise_tts/models/arch_utils.py
+++ b/tortoise_tts/models/arch_utils.py
@@ -1,3 +1,5 @@
+# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/arch_utils.py
+
import os
import functools
import math
diff --git a/tortoise_tts/models/bigvgan.py b/tortoise_tts/models/bigvgan.py
new file mode 100644
index 0000000..a3478e4
--- /dev/null
+++ b/tortoise_tts/models/bigvgan.py
@@ -0,0 +1,772 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+
+import json
+import os
+import torch.utils.data
+
+from torch import nn, sin, pow
+from torch.nn import Conv1d, ConvTranspose1d, Conv2d, Parameter
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from librosa.filters import mel as librosa_mel_fn
+
+
+# filter.py
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+if 'sinc' in dir(torch):
+ sinc = torch.sinc
+else:
+    # This code is adapted from adefossez's julius.core.sinc under the MIT License
+ # https://adefossez.github.io/julius/julius/core.html
+ # LICENSE is in incl_licenses directory.
+ def sinc(x: torch.Tensor):
+ """
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+ """
+ return torch.where(x == 0,
+ torch.tensor(1., device=x.device, dtype=x.dtype),
+ torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+ even = (kernel_size % 2 == 0)
+ half_size = kernel_size // 2
+
+ #For kaiser window
+ delta_f = 4 * half_width
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+ if A > 50.:
+ beta = 0.1102 * (A - 8.7)
+ elif A >= 21.:
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+ else:
+ beta = 0.
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+ if even:
+ time = (torch.arange(-half_size, half_size) + 0.5)
+ else:
+ time = torch.arange(kernel_size) - half_size
+ if cutoff == 0:
+ filter_ = torch.zeros_like(time)
+ else:
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
+ # of the constant component in the input signal.
+ filter_ /= filter_.sum()
+ filter = filter_.view(1, 1, kernel_size)
+
+ return filter
+
+
+class LowPassFilter1d(nn.Module):
+ def __init__(self,
+ cutoff=0.5,
+ half_width=0.6,
+ stride: int = 1,
+ padding: bool = True,
+ padding_mode: str = 'replicate',
+ kernel_size: int = 12):
+ # kernel_size should be even number for stylegan3 setup,
+ # in this implementation, odd number is also possible.
+ super().__init__()
+ if cutoff < -0.:
+ raise ValueError("Minimum cutoff must be larger than zero.")
+ if cutoff > 0.5:
+ raise ValueError("A cutoff above 0.5 does not make sense.")
+ self.kernel_size = kernel_size
+ self.even = (kernel_size % 2 == 0)
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = stride
+ self.padding = padding
+ self.padding_mode = padding_mode
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter)
+
+ #input [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ if self.padding:
+ x = F.pad(x, (self.pad_left, self.pad_right),
+ mode=self.padding_mode)
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
+ stride=self.stride, groups=C)
+
+ return out
+
+# resample.py
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ kernel_size=self.kernel_size)
+ self.register_buffer("filter", filter)
+
+ # x: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
+ x = self.ratio * F.conv_transpose1d(
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+ x = x[..., self.pad_left:-self.pad_right]
+
+ return x
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ stride=ratio,
+ kernel_size=self.kernel_size)
+
+ def forward(self, x):
+ xx = self.lowpass(x)
+
+ return xx
+
+# act.py
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+class Activation1d(nn.Module):
+ def __init__(self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
+
+# activations.py
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+        >>> a1 = Snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake ∶= x + 1/a * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+class SnakeBeta(nn.Module):
+ '''
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+        >>> a1 = SnakeBeta(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+ self.beta = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+# bigvgan.py
+# Copyright (c) 2022 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+LRELU_SLOPE = 0.1
+
+class AMPBlock1(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
+ super(AMPBlock1, self).__init__()
+ self.h = h
+
+ self.convs1 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
+
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=Snake(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'.")
+
+ def forward(self, x):
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+ xt = a1(x)
+ xt = c1(xt)
+ xt = a2(xt)
+ xt = c2(xt)
+ x = xt + x
+
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class AMPBlock2(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
+ super(AMPBlock2, self).__init__()
+ self.h = h
+
+ self.convs = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1])))
+ ])
+ self.convs.apply(init_weights)
+
+ self.num_layers = len(self.convs) # total number of conv layers
+
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=Snake(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'.")
+
+ def forward(self, x):
+ for c, a in zip(self.convs, self.activations):
+ xt = a(x)
+ xt = c(xt)
+ x = xt + x
+
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+class BigVGAN(nn.Module):
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
+ def __init__(self, config=None, data=None):
+ super(BigVGAN, self).__init__()
+
+ """
+ with open(os.path.join(os.path.dirname(__file__), 'config.json'), 'r') as f:
+ data = f.read()
+ """
+ if config and data is None:
+ with open(config, 'r') as f:
+ data = f.read()
+ jsonConfig = json.loads(data)
+ elif data is not None:
+ if isinstance(data, str):
+ jsonConfig = json.loads(data)
+ else:
+ jsonConfig = data
+ else:
+ raise Exception("no config specified")
+
+
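+		# the parsed config is stashed in a module-level global so helpers like get_mel() and inference() can read it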
+ global h
+ h = AttrDict(jsonConfig)
+
+ self.mel_channel = h.num_mels
+ self.noise_dim = h.n_fft
+ self.hop_length = h.hop_size
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+
+ # pre conv
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
+ resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
+
+ # transposed conv-based upsamplers. does not apply anti-aliasing
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(nn.ModuleList([
+ weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
+ h.upsample_initial_channel // (2 ** (i + 1)),
+ k, u, padding=(k - u) // 2))
+ ]))
+
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+ self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
+
+ # post conv
+ if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
+ activation_post = Snake(ch, alpha_logscale=h.snake_logscale)
+ self.activation_post = Activation1d(activation=activation_post)
+ elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
+ activation_post = SnakeBeta(ch, alpha_logscale=h.snake_logscale)
+ self.activation_post = Activation1d(activation=activation_post)
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'.")
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+ # weight initialization
+ for i in range(len(self.ups)):
+ self.ups[i].apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self,x, c):
+ # pre conv
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ # upsampling
+ for i_up in range(len(self.ups[i])):
+ x = self.ups[i][i_up](x)
+ # AMP blocks
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ # post conv
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ for l_i in l:
+ remove_weight_norm(l_i)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+ def inference(self, c, z=None):
+ # pad input mel with zeros to cut artifact
+ # see https://github.com/seungwonpark/melgan/issues/8
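+		# -11.5129 ≈ log(1e-5), i.e. silence in the log-mel domain after dynamic range compression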
+ zero = torch.full((c.shape[0], h.num_mels, 10), -11.5129).to(c.device)
+ mel = torch.cat((c, zero), dim=2)
+
+ if z is None:
+ z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
+
+ audio = self.forward(mel, z)
+ audio = audio[:, :, :-(self.hop_length * 10)]
+ audio = audio.clamp(min=-1, max=1)
+ return audio
+
+ def eval(self, inference=False):
+ super(BigVGAN, self).eval()
+ # don't remove weight norm while validation in training loop
+ if inference:
+ self.remove_weight_norm()
+
+
+class DiscriminatorP(nn.Module):
+ def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ self.d_mult = h.discriminator_channel_mult
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
+ ])
+ self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(nn.Module):
+ def __init__(self, h):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.mpd_reshapes = h.mpd_reshapes
+ print("mpd_reshapes: {}".format(self.mpd_reshapes))
+ discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
+ self.discriminators = nn.ModuleList(discriminators)
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+ def __init__(self, cfg, resolution):
+ super().__init__()
+
+ self.resolution = resolution
+ assert len(self.resolution) == 3, \
+ "MRD layer requires list with len=3, got {}".format(self.resolution)
+ self.lrelu_slope = LRELU_SLOPE
+
+ norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
+ if hasattr(cfg, "mrd_use_spectral_norm"):
+ print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm))
+ norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
+ self.d_mult = cfg.discriminator_channel_mult
+ if hasattr(cfg, "mrd_channel_mult"):
+ print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult))
+ self.d_mult = cfg.mrd_channel_mult
+
+ self.convs = nn.ModuleList([
+ norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
+ norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
+ ])
+ self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.spectrogram(x)
+ x = x.unsqueeze(1)
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, self.lrelu_slope)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+ def spectrogram(self, x):
+ n_fft, hop_length, win_length = self.resolution
+ x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
+ x = x.squeeze(1)
+ x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
+ x = torch.view_as_real(x) # [B, F, TT, 2]
+ mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
+
+ return mag
+
+
+class MultiResolutionDiscriminator(nn.Module):
+ def __init__(self, cfg, debug=False):
+ super().__init__()
+ self.resolutions = cfg.resolutions
+ assert len(self.resolutions) == 3, \
+ "MRD requires list of list with len=3, each element having a list with len=3. got {}". \
+ format(self.resolutions)
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(x=y)
+ y_d_g, fmap_g = d(x=y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+mel_basis = {}
+hann_window = {}
+
+def get_mel(x):
+	return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+ if torch.min(y) < -1.:
+ print('min value is ', torch.min(y))
+ if torch.max(y) > 1.:
+ print('max value is ', torch.max(y))
+
+	global mel_basis, hann_window
+	if str(fmax)+'_'+str(y.device) not in mel_basis:
+		# newer librosa versions require keyword arguments here
+		mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+		mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+		hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+ y = y.squeeze(1)
+
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+ spec = torch.view_as_real(spec)
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
+	spec = torch.log(torch.clamp(spec, min=1e-5)) # spectral normalization (dynamic range compression)
+
+ return spec
+
+def feature_loss(fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss * 2
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean((1 - dr) ** 2)
+ g_loss = torch.mean(dg ** 2)
+ loss += (r_loss + g_loss)
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean((1 - dg) ** 2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+
+if __name__ == '__main__':
+ model = BigVGAN()
+
+ c = torch.randn(3, 100, 10)
+ z = torch.randn(3, 64, 10)
+ print(c.shape)
+
+ y = model(c, z)
+ print(y.shape)
+ assert y.shape == torch.Size([3, 1, 2560])
+
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(pytorch_total_params)
\ No newline at end of file
diff --git a/tortoise_tts/models/classifier.py b/tortoise_tts/models/classifier.py
index cf00f3a..10331f2 100644
--- a/tortoise_tts/models/classifier.py
+++ b/tortoise_tts/models/classifier.py
@@ -1,3 +1,5 @@
+# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/classifier.py
+
import torch
import torch.nn as nn
diff --git a/tortoise_tts/models/clvp.py b/tortoise_tts/models/clvp.py
index e6c43c4..b2ccda8 100644
--- a/tortoise_tts/models/clvp.py
+++ b/tortoise_tts/models/clvp.py
@@ -1,3 +1,5 @@
+# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/clvp.py
+
import torch
import torch.nn as nn
import torch.nn.functional as F
diff --git a/tortoise_tts/models/diffusion.py b/tortoise_tts/models/diffusion.py
index 00c7c1e..c0137a3 100644
--- a/tortoise_tts/models/diffusion.py
+++ b/tortoise_tts/models/diffusion.py
@@ -1,3 +1,5 @@
+# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/diffusion.py
+
import enum
import math
import random
diff --git a/tortoise_tts/models/hifigan.py b/tortoise_tts/models/hifigan.py
new file mode 100644
index 0000000..0883b3c
--- /dev/null
+++ b/tortoise_tts/models/hifigan.py
@@ -0,0 +1,305 @@
+# Grabbed from https://git.ecker.tech/Jarod/tortoise-tts/src/branch/main
+# Adapted from https://github.com/jik876/hifi-gan/blob/master/models.py
+
+import torch
+from torch import nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
+
+LRELU_SLOPE = 0.1
+
+
+def get_padding(k, d):
+ return int((k * d - d) / 2)
+
+
+class ResBlock1(torch.nn.Module):
+ """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
+
+ Network::
+
+ x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
+ |--------------------------------------------------------------------------------------------------|
+
+
+ Args:
+ channels (int): number of hidden channels for the convolutional layers.
+ kernel_size (int): size of the convolution filter in each layer.
+ dilations (list): list of dilation value for each conv layer in a block.
+ """
+
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super().__init__()
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
+ ),
+ weight_norm(
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
+ ),
+ weight_norm(
+ Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
+ ),
+ ]
+ )
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): input tensor.
+ Returns:
+ Tensor: output tensor.
+ Shapes:
+ x: [B, C, T]
+ """
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ """Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
+
+ Network::
+
+ x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
+ |---------------------------------------------------|
+
+
+ Args:
+ channels (int): number of hidden channels for the convolutional layers.
+ kernel_size (int): size of the convolution filter in each layer.
+ dilations (list): list of dilation value for each conv layer in a block.
+ """
+
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+ super().__init__()
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+
+ def forward(self, x):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class HifiganGenerator(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ resblock_type,
+ resblock_dilation_sizes,
+ resblock_kernel_sizes,
+ upsample_kernel_sizes,
+ upsample_initial_channel,
+ upsample_factors,
+ inference_padding=5,
+ cond_channels=0,
+ conv_pre_weight_norm=True,
+ conv_post_weight_norm=True,
+ conv_post_bias=True,
+ ):
+ r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
+
+ Network:
+ x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
+ .. -> zI ---|
+ resblockN_kNx1 -> zN ---'
+
+ Args:
+ in_channels (int): number of input tensor channels.
+ out_channels (int): number of output tensor channels.
+ resblock_type (str): type of the `ResBlock`. '1' or '2'.
+ resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
+ resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
+ upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
+ upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
+ for each consecutive upsampling layer.
+ upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
+ inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
+ """
+ super().__init__()
+ self.inference_padding = inference_padding
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_factors)
+ # initial upsampling layers
+ self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
+ resblock = ResBlock1 if resblock_type == "1" else ResBlock2
+ # upsampling layers
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+ # MRF blocks
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(resblock(ch, k, d))
+ # post convolution layer
+ self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
+ if cond_channels > 0:
+ self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
+
+ if not conv_pre_weight_norm:
+ remove_weight_norm(self.conv_pre)
+
+ if not conv_post_weight_norm:
+ remove_weight_norm(self.conv_post)
+
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if torch.backends.mps.is_available():
+ self.device = torch.device('mps')
+
+ def forward(self, x, g=None):
+ """
+ Args:
+ x (Tensor): feature input tensor.
+ g (Tensor): global conditioning input tensor.
+
+ Returns:
+ Tensor: output waveform.
+
+ Shapes:
+ x: [B, C, T]
+ Tensor: [B, 1, T]
+ """
+ o = self.conv_pre(x)
+ if hasattr(self, "cond_layer"):
+ o = o + self.cond_layer(g)
+ for i in range(self.num_upsamples):
+ o = F.leaky_relu(o, LRELU_SLOPE)
+ o = self.ups[i](o)
+ z_sum = None
+ for j in range(self.num_kernels):
+ if z_sum is None:
+ z_sum = self.resblocks[i * self.num_kernels + j](o)
+ else:
+ z_sum += self.resblocks[i * self.num_kernels + j](o)
+ o = z_sum / self.num_kernels
+ o = F.leaky_relu(o)
+ o = self.conv_post(o)
+ o = torch.tanh(o)
+ return o
+
+ @torch.no_grad()
+ def inference(self, c, g=None):
+ """
+ Args:
+            c (Tensor): conditioning latent tensor (the autoregressive model's final hidden states).
+
+ Returns:
+ Tensor: output waveform.
+
+ Shapes:
+            c: [B, T, C]
+ Tensor: [B, 1, T]
+ """
+ # c = c.to(self.conv_pre.weight.device)
+ # c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
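+        # presumably this stretches the latent sequence to the frame rate the pretrained decoder
+        # expects: a 4x (1024/256) upsample followed by a 22050 Hz -> 24000 Hz adjustment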
+ up_1 = torch.nn.functional.interpolate(
+ c.transpose(1,2),
+ scale_factor=[1024 / 256],
+ mode="linear",
+ )
+ up_2 = torch.nn.functional.interpolate(
+ up_1,
+ scale_factor=[24000 / 22050],
+ mode="linear",
+ )
+ g = g.unsqueeze(0)
+ return self.forward(up_2.to(self.device), g.transpose(1,2))
+
+ def remove_weight_norm(self):
+ print("Removing weight norm...")
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
\ No newline at end of file
diff --git a/tortoise_tts/models/lora.py b/tortoise_tts/models/lora.py
index 30ca4c3..cee0605 100644
--- a/tortoise_tts/models/lora.py
+++ b/tortoise_tts/models/lora.py
@@ -1,4 +1,5 @@
# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+
from functools import partial
import torch
import torch.nn.functional as F
diff --git a/tortoise_tts/models/random_latent_generator.py b/tortoise_tts/models/random_latent_generator.py
index 4d6ff80..6fa4494 100644
--- a/tortoise_tts/models/random_latent_generator.py
+++ b/tortoise_tts/models/random_latent_generator.py
@@ -1,3 +1,6 @@
+# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/random_latent_generator.py
+
+
import math
import torch
diff --git a/tortoise_tts/models/stream_generator.py b/tortoise_tts/models/stream_generator.py
new file mode 100644
index 0000000..5f41407
--- /dev/null
+++ b/tortoise_tts/models/stream_generator.py
@@ -0,0 +1,1064 @@
+# Adapted from https://git.ecker.tech/Jarod/tortoise-tts/raw/commit/156bb5e7da98128071ab94fd4f3884e64566f3dc/tortoise/models/stream_generator.py
+# Adapted from: https://github.com/LowinLi/transformers-stream-generator
+
+from transformers import (
+ GenerationConfig,
+ GenerationMixin,
+ LogitsProcessorList,
+ StoppingCriteriaList,
+ DisjunctiveConstraint,
+ BeamSearchScorer,
+ PhrasalConstraint,
+ ConstrainedBeamSearchScorer,
+ PreTrainedModel,
+)
+import numpy as np
+import random
+import warnings
+import inspect
+from transformers.generation.utils import GenerateOutput, SampleOutput, logger
+import torch
+from typing import Callable, List, Optional, Union
+from torch import nn
+import torch.distributed as dist
+import copy
+
+
+def setup_seed(seed):
+ if seed == -1:
+ return
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+
+class StreamGenerationConfig(GenerationConfig):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.do_stream = kwargs.pop("do_stream", False)
+
+
+class NewGenerationMixin(GenerationMixin):
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[StreamGenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[
+ Callable[[int, torch.Tensor], List[int]]
+ ] = None,
+ synced_gpus: Optional[bool] = False,
+ seed=0,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ r"""
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which had the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
+                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+ Retrieval](https://arxiv.org/abs/2010.00904).
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ kwargs:
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
+ - [`~generation.SampleDecoderOnlyOutput`],
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
+ - [`~generation.SampleEncoderDecoderOutput`],
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
+ """
+ setup_seed(seed)
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+ self._validate_model_class()
+
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+ if generation_config is None:
+ # legacy: users may modify the model configuration to control generation -- update the generation config
+ # model attribute accordingly, if it was created from the model config
+ if self.generation_config._from_model_config:
+ new_generation_config = StreamGenerationConfig.from_model_config(
+ self.config
+ )
+ if new_generation_config != self.generation_config:
+ warnings.warn(
+ "You have modified the pretrained model configuration to control generation. This is a"
+ " deprecated strategy to control generation and will be removed soon, in a future version."
+ " Please use a generation configuration file (see"
+ " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+ )
+ self.generation_config = new_generation_config
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(
+ **kwargs
+ ) # All unused kwargs must be model kwargs
+ # self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = (
+ logits_processor if logits_processor is not None else LogitsProcessorList()
+ )
+ stopping_criteria = (
+ stopping_criteria
+ if stopping_criteria is not None
+ else StoppingCriteriaList()
+ )
+
+ if (
+ generation_config.pad_token_id is None
+ and generation_config.eos_token_id is not None
+ ):
+ if model_kwargs.get("attention_mask", None) is None:
+ logger.warning(
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+ )
+ eos_token_id = generation_config.eos_token_id
+ if isinstance(eos_token_id, list):
+ eos_token_id = eos_token_id[0]
+ logger.warning(
+ f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
+ )
+ generation_config.pad_token_id = eos_token_id
+
+ # 3. Define model inputs
+ # inputs_tensor has to be defined
+ # model_input_name is defined if model-specific keyword input is passed
+ # otherwise model_input_name is None
+ # all model-specific keyword inputs are removed from `model_kwargs`
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+
+ # 4. Define other model kwargs
+ model_kwargs["output_attentions"] = generation_config.output_attentions
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ model_kwargs["use_cache"] = generation_config.use_cache
+
+ accepts_attention_mask = "attention_mask" in set(
+ inspect.signature(self.forward).parameters.keys()
+ )
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+
+ if (
+ model_kwargs.get("attention_mask", None) is None
+ and requires_attention_mask
+ and accepts_attention_mask
+ ):
+ pad_token_tensor = torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device) if generation_config.pad_token_id is not None else None
+ eos_token_tensor = torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device) if generation_config.eos_token_id is not None else None
+
+ model_kwargs[
+ "attention_mask"
+ ] = self._prepare_attention_mask_for_generation(
+ inputs_tensor,
+ pad_token_tensor,
+ eos_token_tensor,
+ )
+
+ # decoder-only models should use left-padding for generation
+ if not self.config.is_encoder_decoder:
+ if (
+ generation_config.pad_token_id is not None
+ and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
+ > 0
+ ):
+ logger.warning(
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
+ )
+
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+ # if model is encoder decoder encoder_outputs are created
+ # and added to `model_kwargs`
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ if self.config.is_encoder_decoder:
+ input_ids = self._prepare_decoder_input_ids_for_generation(
+ batch_size,
+ decoder_start_token_id=generation_config.decoder_start_token_id,
+ bos_token_id=generation_config.bos_token_id,
+ model_kwargs=model_kwargs,
+ device=inputs_tensor.device,
+ )
+ else:
+ # if decoder-only then inputs_tensor has to be `input_ids`
+ input_ids = inputs_tensor
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_seq_length = input_ids.shape[-1]
+ has_default_max_length = (
+ kwargs.get("max_length") is None
+ and generation_config.max_length is not None
+ )
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
+ f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+ " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif has_default_max_length and generation_config.max_new_tokens is not None:
+ generation_config.max_length = (
+ generation_config.max_new_tokens + input_ids_seq_length
+ )
+ elif (
+ not has_default_max_length and generation_config.max_new_tokens is not None
+ ):
+ raise ValueError(
+ "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
+ " limit to the generated output length. Remove one of those arguments. Please refer to the"
+ " documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+
+ if (
+ generation_config.min_length is not None
+ and generation_config.min_length > generation_config.max_length
+ ):
+ raise ValueError(
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+ f" the maximum length ({generation_config.max_length})"
+ )
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = (
+ "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ )
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ if not hasattr(generation_config, "do_stream"):
+ generation_config.do_stream = kwargs.get("do_stream", False)
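+ # `do_stream` is not a standard GenerationConfig field, so it is pulled from the raw kwargs here;
+ # when True, the mode dispatch below routes single-beam requests to `sample_stream`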
+
+ # 7. determine generation mode
+ is_constraint_gen_mode = (
+ generation_config.constraints is not None
+ or generation_config.force_words_ids is not None
+ ) and generation_config.do_stream is False
+
+ is_contrastive_search_gen_mode = (
+ generation_config.top_k is not None
+ and generation_config.top_k > 1
+ and generation_config.do_sample is False
+ and generation_config.penalty_alpha is not None
+ and generation_config.penalty_alpha > 0
+ ) and generation_config.do_stream is False
+
+ is_greedy_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ ) and generation_config.do_stream is False
+ is_sample_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
+ and generation_config.do_stream is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ ) and generation_config.do_stream is False
+ is_sample_gen_stream_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_stream is True
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
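+ # streaming is only reachable for single-beam generation: every other mode requires
+ # `do_stream is False`, and a streaming request always goes through multinomial sampling in
+ # `sample_stream`, regardless of `do_sample`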
+ is_beam_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ ) and generation_config.do_stream is False
+ is_beam_sample_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ ) and generation_config.do_stream is False
+ is_group_beam_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups > 1)
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ ) and generation_config.do_stream is False
+
+ if generation_config.num_beam_groups > generation_config.num_beams:
+ raise ValueError(
+ "`num_beam_groups` has to be smaller or equal to `num_beams`"
+ )
+ if is_group_beam_gen_mode and generation_config.do_sample is True:
+ raise ValueError(
+ "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
+ )
+
+ if self.device.type != input_ids.device.type:
+ warnings.warn(
+ "You are calling .generate() with the `input_ids` being on a device type different"
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
+ " Please make sure that you have put `input_ids` to the"
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
+ " running `.generate()`.",
+ UserWarning,
+ )
+ # 8. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ # 9. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ # 10. go into different generation modes
+ if is_greedy_gen_mode:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " greedy search."
+ )
+
+ # 11. run greedy search
+ return self.greedy_search(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_contrastive_search_gen_mode:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " contrastive search."
+ )
+
+ return self.contrastive_search(
+ input_ids,
+ top_k=generation_config.top_k,
+ penalty_alpha=generation_config.penalty_alpha,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_sample_gen_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config, input_ids.device)
+
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 13. run sample
+ return self._sample(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ streamer=None,
+ generation_config=generation_config,
+ **model_kwargs,
+ )
+ elif is_sample_gen_stream_mode:
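+ # streaming branch: identical setup to `is_sample_gen_mode`, but it dispatches to the generator
+ # function `sample_stream`, so the call below returns a generator that produces tokens lazily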
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config, input_ids.device)
+
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 13. run sample
+ return self.sample_stream(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+ elif is_beam_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ # 11. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.beam_search(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_beam_sample_gen_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config, input_ids.device)
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+ # 12. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size * generation_config.num_return_sequences,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ )
+
+ # 13. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams
+ * generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 14. run beam sample
+ return self.beam_sample(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_group_beam_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if generation_config.num_beams % generation_config.num_beam_groups != 0:
+ raise ValueError(
+ "`num_beams` should be divisible by `num_beam_groups` for group beam search."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ has_default_typical_p = (
+ kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
+ )
+ if not has_default_typical_p:
+ raise ValueError(
+ "Decoder argument `typical_p` is not supported with beam groups."
+ )
+
+ # 11. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ max_length=stopping_criteria.max_length,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ num_beam_groups=generation_config.num_beam_groups,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.group_beam_search(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_constraint_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ if generation_config.num_beams <= 1:
+ raise ValueError(
+ "`num_beams` needs to be greater than 1 for constrained generation."
+ )
+
+ if generation_config.do_sample:
+ raise ValueError(
+ "`do_sample` needs to be false for constrained generation."
+ )
+
+ if (
+ generation_config.num_beam_groups is not None
+ and generation_config.num_beam_groups > 1
+ ):
+ raise ValueError(
+ "`num_beam_groups` not supported yet for constrained generation."
+ )
+
+ final_constraints = []
+ if generation_config.constraints is not None:
+ final_constraints = generation_config.constraints
+
+ if generation_config.force_words_ids is not None:
+
+ def typeerror():
+ raise ValueError(
+ "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+ f"of positive integers, but is {generation_config.force_words_ids}."
+ )
+
+ if (
+ not isinstance(generation_config.force_words_ids, list)
+ or len(generation_config.force_words_ids) == 0
+ ):
+ typeerror()
+
+ for word_ids in generation_config.force_words_ids:
+ if isinstance(word_ids[0], list):
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
+ typeerror()
+ if any(
+ not isinstance(token_ids, list) for token_ids in word_ids
+ ):
+ typeerror()
+ if any(
+ any(
+ (not isinstance(token_id, int) or token_id < 0)
+ for token_id in token_ids
+ )
+ for token_ids in word_ids
+ ):
+ typeerror()
+
+ constraint = DisjunctiveConstraint(word_ids)
+ else:
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
+ typeerror()
+ if any(
+ (not isinstance(token_id, int) or token_id < 0)
+ for token_id in word_ids
+ ):
+ typeerror()
+
+ constraint = PhrasalConstraint(word_ids)
+ final_constraints.append(constraint)
+
+ # 11. prepare beam search scorer
+ constrained_beam_scorer = ConstrainedBeamSearchScorer(
+ constraints=final_constraints,
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.constrained_beam_search(
+ input_ids,
+ constrained_beam_scorer=constrained_beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ @torch.no_grad()
+ def sample_stream(
+ self,
+ input_ids: torch.LongTensor,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ logits_warper: Optional[LogitsProcessorList] = None,
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_scores: Optional[bool] = None,
+ return_dict_in_generate: Optional[bool] = None,
+ synced_gpus: Optional[bool] = False,
+ **model_kwargs,
+ ) -> Union[SampleOutput, torch.LongTensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+ In most cases, you do not need to call `sample_stream` directly: pass `do_stream=True` to
+ `generate()` and it will dispatch here, returning a generator that yields `(next_tokens, latent)`
+ pairs as they are sampled. For an overview of generation strategies and code examples, check the
+ [following guide](./generation_strategies).
+
+ Parameters:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ logits_warper (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+ to warp the prediction score distribution of the language modeling head applied before multinomial
+ sampling at each generation step.
+ max_length (`int`, *optional*, defaults to 20):
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
+ tokens. The maximum length of the sequence to be generated.
+ pad_token_id (`int`, *optional*):
+ The id of the *padding* token.
+ eos_token_id (`int`, *optional*):
+ The id of the *end-of-sequence* token.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more details.
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more details.
+ output_scores (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Yields:
+ `Tuple[torch.LongTensor, torch.FloatTensor]`: at each decoding step, the sampled token ids of
+ shape `(batch_size,)` together with the last hidden state of that step passed through
+ `self.final_norm` (the per-token latent consumed by the streaming caller).
+
+ Examples:
+ (The example below is the non-streaming reference inherited from `sample`; for streaming, call
+ `generate(..., do_stream=True)` as in the `__main__` block at the bottom of this file.)
+
+ ```python
+ >>> from transformers import (
+ ... AutoTokenizer,
+ ... AutoModelForCausalLM,
+ ... LogitsProcessorList,
+ ... MinLengthLogitsProcessor,
+ ... TopKLogitsWarper,
+ ... TemperatureLogitsWarper,
+ ... StoppingCriteriaList,
+ ... MaxLengthCriteria,
+ ... )
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+ >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
+ >>> model.config.pad_token_id = model.config.eos_token_id
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
+
+ >>> input_prompt = "Today is a beautiful day, and"
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
+
+ >>> # instantiate logits processors
+ >>> logits_processor = LogitsProcessorList(
+ ... [
+ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
+ ... ]
+ ... )
+ >>> # instantiate logits processors
+ >>> logits_warper = LogitsProcessorList(
+ ... [
+ ... TopKLogitsWarper(50),
+ ... TemperatureLogitsWarper(0.7),
+ ... ]
+ ... )
+
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
+ >>> outputs = model.sample(
+ ... input_ids,
+ ... logits_processor=logits_processor,
+ ... logits_warper=logits_warper,
+ ... stopping_criteria=stopping_criteria,
+ ... )
+
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
+ ```"""
+ # init values
+ logits_processor = (
+ logits_processor if logits_processor is not None else LogitsProcessorList()
+ )
+ stopping_criteria = (
+ stopping_criteria
+ if stopping_criteria is not None
+ else StoppingCriteriaList()
+ )
+ if max_length is not None:
+ warnings.warn(
+ "`max_length` is deprecated in this function, use"
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+ UserWarning,
+ )
+ stopping_criteria = validate_stopping_criteria(
+ stopping_criteria, max_length
+ )
+ logits_warper = (
+ logits_warper if logits_warper is not None else LogitsProcessorList()
+ )
+ pad_token_id = (
+ pad_token_id
+ if pad_token_id is not None
+ else self.generation_config.pad_token_id
+ )
+ eos_token_id = (
+ eos_token_id
+ if eos_token_id is not None
+ else self.generation_config.eos_token_id
+ )
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ output_scores = (
+ output_scores
+ if output_scores is not None
+ else self.generation_config.output_scores
+ )
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.generation_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.generation_config.output_hidden_states
+ )
+ return_dict_in_generate = (
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
+ )
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ decoder_attentions = (
+ () if (return_dict_in_generate and output_attentions) else None
+ )
+ cross_attentions = (
+ () if (return_dict_in_generate and output_attentions) else None
+ )
+ decoder_hidden_states = (
+ () if (return_dict_in_generate and output_hidden_states) else None
+ )
+
+ # keep track of which sequences are already finished
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+
+ this_peer_finished = False # used by synced_gpus only
+ # auto-regressive generation
+ while True:
+ if synced_gpus:
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+ # The following logic allows an early break if all peers finished generating their sequence
+ this_peer_finished_flag = torch.tensor(
+ 0.0 if this_peer_finished else 1.0
+ ).to(input_ids.device)
+ # send 0.0 if we finished, 1.0 otherwise
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+ # did all peers finish? the reduced sum will be 0.0 then
+ if this_peer_finished_flag.item() == 0.0:
+ break
+
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # forward pass to get next token
+ # hidden states are requested unconditionally because the streaming yield below reads
+ # `outputs.hidden_states[-1]` to build the per-step latent
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ )
+
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't need
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,)
+ if self.config.is_encoder_decoder
+ else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,)
+ if self.config.is_encoder_decoder
+ else (outputs.hidden_states,)
+ )
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError(
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+ )
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+ 1 - unfinished_sequences
+ )
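+ # stream the step result: the sampled token ids together with the final-norm'd hidden state of the
+ # new position; this assumes the model defines `final_norm` (GPT2InferenceModel in this repo
+ # assigns it in __init__), and the latent is what the streaming caller consumes per token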
+ yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id is not None:
+ unfinished_sequences = unfinished_sequences.mul(
+ (sum(next_tokens != i for i in eos_token_id)).long()
+ )
+
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ if not synced_gpus:
+ break
+ else:
+ this_peer_finished = True
+
+
+def init_stream_support():
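+ # Patches the streaming-aware generate() and sample_stream onto PreTrainedModel so any HF model
+ # (including the repo's GPT2InferenceModel) accepts `do_stream=True`; note this overrides
+ # generate() for every model in the process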
+ PreTrainedModel.generate = NewGenerationMixin.generate
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
+
+if __name__ == "__main__":
+ from transformers import PreTrainedModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ init_stream_support()
+ model = AutoModelForCausalLM.from_pretrained(
+ "bigscience/bloom-560m", torch_dtype=torch.float16
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+ model = model.to("cuda:0")
+ model = model.eval()
+ # sample_stream computes its per-step latent with `self.final_norm`, which a stock HF model such
+ # as bloom-560m does not define; an identity module is patched in here (an assumption purely for
+ # this standalone smoke test) so the streaming loop below can run; the real path uses
+ # GPT2InferenceModel, which sets final_norm in __init__
+ model.final_norm = torch.nn.Identity()
+ prompt_text = "hello? \n"
+ input_ids = tokenizer(
+ prompt_text, return_tensors="pt", add_special_tokens=False
+ ).input_ids
+ input_ids = input_ids.to("cuda:0")
+
+ with torch.no_grad():
+ result = model.generate(
+ input_ids,
+ max_new_tokens=200,
+ do_sample=True,
+ top_k=30,
+ top_p=0.85,
+ temperature=0.35,
+ repetition_penalty=1.2,
+ early_stopping=True,
+ seed=0,
+ )
+ # generate() returns a (batch, seq_len) tensor; decode the single returned sequence
+ print(tokenizer.decode(result[0], skip_special_tokens=True))
+ generator = model.generate(
+ input_ids,
+ max_new_tokens=200,
+ do_sample=True,
+ top_k=30,
+ top_p=0.85,
+ temperature=0.35,
+ repetition_penalty=1.2,
+ early_stopping=True,
+ seed=0,
+ do_stream=True,
+ )
+ stream_result = ""
+ # sample_stream yields (token_ids, latent) pairs; only the token ids are decoded here
+ for tokens, latent in generator:
+ chunk = tokenizer.decode(tokens, skip_special_tokens=True)
+ stream_result += chunk
+ print(stream_result)
\ No newline at end of file
diff --git a/tortoise_tts/models/transformer.py b/tortoise_tts/models/transformer.py
index 353a949..b64d4db 100644
--- a/tortoise_tts/models/transformer.py
+++ b/tortoise_tts/models/transformer.py
@@ -1,3 +1,5 @@
+# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/transformer.py
+
from functools import partial
import torch
diff --git a/tortoise_tts/models/unified_voice.py b/tortoise_tts/models/unified_voice.py
index c1d6096..f61f41e 100644
--- a/tortoise_tts/models/unified_voice.py
+++ b/tortoise_tts/models/unified_voice.py
@@ -1,3 +1,5 @@
+# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/unified_voice.py
+
import functools
import torch
@@ -14,6 +16,8 @@ from transformers import LogitsWarper
from transformers import GPT2Config, GPT2Model
from tqdm import tqdm
+from .stream_generator import NewGenerationMixin
+
AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
try:
@@ -83,12 +87,14 @@ class ResBlock(nn.Module):
def forward(self, x):
return F.relu(self.net(x) + x)
-class GPT2InferenceModel(GPT2PreTrainedModel):
+class GPT2InferenceModel(GPT2PreTrainedModel, NewGenerationMixin):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True):
- super().__init__(config)
+ super(NewGenerationMixin, self).__init__()
+ super(GPT2PreTrainedModel, self).__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.embeddings = embeddings
+ self.final_norm = norm
self.lm_head = nn.Sequential(norm, linear)
self.kv_cache = kv_cache
@@ -129,14 +135,14 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb
- def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
if not self.kv_cache:
- past = None
+ past_key_values = None
# only last token for inputs_ids if past is defined in kwargs
- if past:
+ if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@@ -148,13 +154,13 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
- if past:
+ if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
- "past_key_values": past,
+ "past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
@@ -597,6 +603,24 @@ class UnifiedVoice(nn.Module):
return loss_text.mean(), loss_mel.mean(), mel_logits
+ def compute_embeddings( self, cond_latents, text_inputs, kv_cache = True ):
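+ # Builds the autoregressive prompt for streaming: pads the text tokens with start/stop text
+ # tokens, embeds them, prepends the conditioning latents, and caches the combined embedding on
+ # the inference model. The returned tensor is a placeholder id sequence (same trick as the fake
+ # inputs in `inference_speech`) whose final position is `start_mel_token`, intended to be handed
+ # to `get_generator` as the initial `input_ids`.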
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
+ text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
+ emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
+ conds = cond_latents.unsqueeze(1)
+ emb = torch.cat([conds, emb], dim=1)
+
+ if not hasattr(self, 'inference_model'):
+ # TODO: Decouple gpt_config from this inference model.
+ self.post_init_gpt2_config(kv_cache = kv_cache)
+
+ self.inference_model.store_mel_emb(emb)
+
+ embs = torch.full( ( emb.shape[0], emb.shape[1] + 1 ), fill_value=1, dtype=torch.long, device=text_inputs.device )
+ embs[:, -1] = self.start_mel_token
+
+ return embs
+
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
@@ -635,6 +659,17 @@ class UnifiedVoice(nn.Module):
self.inference_model.bar.close()
return gen[:, trunc_index:]
+ def get_generator(self, inputs, max_length=500, **hf_generate_kwargs):
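+ # Streaming counterpart to `inference_speech`: returns the lazy generator produced by the patched
+ # `generate(..., do_stream=True)`, which yields (mel token, latent) pairs until `stop_mel_token`
+ # (or `max_length`) is reached.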
+ return self.inference_model.generate(
+ inputs,
+ bos_token_id=self.start_mel_token,
+ pad_token_id=self.stop_mel_token,
+ eos_token_id=self.stop_mel_token,
+ max_length=max_length,
+ do_stream=True,
+ **hf_generate_kwargs,
+ )
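+
+ # A minimal sketch of how these two hooks can be driven together (the surrounding loop and the
+ # `cond_latents` / `text_tokens` names are assumptions for illustration, not repository API):
+ #
+ # inputs = gpt.compute_embeddings(cond_latents, text_tokens)
+ # latents = []
+ # for token, latent in gpt.get_generator(inputs, max_length=500):
+ #     latents.append(latent)
+ #     if token.item() == gpt.stop_mel_token:
+ #         break
+ # # the accumulated latents can then be decoded to audio in overlapping chunks by the caller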
+
if __name__ == '__main__':
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
diff --git a/tortoise_tts/models/vocoder.py b/tortoise_tts/models/vocoder.py
index 39ec9e0..b825ca6 100644
--- a/tortoise_tts/models/vocoder.py
+++ b/tortoise_tts/models/vocoder.py
@@ -1,3 +1,5 @@
+# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/vocoder.py
+
import torch
import torch.nn as nn
import torch.nn.functional as F
diff --git a/tortoise_tts/models/xtransformers.py b/tortoise_tts/models/xtransformers.py
index 4accb20..c75c9e0 100644
--- a/tortoise_tts/models/xtransformers.py
+++ b/tortoise_tts/models/xtransformers.py
@@ -1,3 +1,5 @@
+# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/xtransformers.py
+
import math
from collections import namedtuple
from functools import partial
diff --git a/tortoise_tts/webui.py b/tortoise_tts/webui.py
index 4969b40..2324089 100644
--- a/tortoise_tts/webui.py
+++ b/tortoise_tts/webui.py
@@ -95,6 +95,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"])
parser.add_argument("--cond-free", type=str, default=kwargs["cond-free"])
+ parser.add_argument("--vocoder", type=str, default=kwargs["vocoder"].lower())
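+ # the UI radio exposes "Vocoder" / "BigVGAN" / "HiFiGAN"; lowercasing maps the choice onto the
+ # identifiers the inference path expects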
"""
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
@@ -126,6 +127,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
beam_width=args.beam_width,
diffusion_sampler=args.diffusion_sampler,
+ vocoder_type=args.vocoder,
)
wav = wav.squeeze(0).cpu().numpy()
@@ -210,7 +212,7 @@ with ui:
with gr.Column(scale=1):
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
# layout["inference"]["stop"] = gr.Button(value="Stop")
- layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
+ layout["inference"]["outputs"]["output"] = gr.Audio(label="Output", streaming=True)
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
with gr.Column(scale=7):
with gr.Row():
@@ -221,6 +223,7 @@ with ui:
with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.8, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.")
+ layout["inference"]["inputs"]["vocoder"] = gr.Radio( ["Vocoder", "BigVGAN", "HiFiGAN"], value="BigVGAN", label="Vocoder", type="value", info="Vocoder to use for generating the final waveform (HiFiGAN skips diffusion)." )
"""
with gr.Row():
layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")