added BigVGAN and HiFiGAN (from https://git.ecker.tech/jarod/tortoise-tts), vocoder selectable in webUI

This commit is contained in:
mrq 2024-06-19 21:43:29 -05:00
parent a5c21d65d2
commit 0b1a71430c
19 changed files with 2364 additions and 18 deletions

View File

@ -1,3 +1,5 @@
vocoder: "hifigan"
models: models:
- name: "autoregressive" - name: "autoregressive"
training: True training: True
@ -71,7 +73,7 @@ trainer:
backend: deepspeed backend: deepspeed
deepspeed: deepspeed:
inferencing: True inferencing: False
zero_optimization_level: 0 zero_optimization_level: 0
use_compression_training: False use_compression_training: False

View File

@ -24,6 +24,7 @@ def main():
parser.add_argument("--diffusion-sampler", type=str, default="ddim") parser.add_argument("--diffusion-sampler", type=str, default="ddim")
parser.add_argument("--cond-free", action="store_true") parser.add_argument("--cond-free", action="store_true")
parser.add_argument("--vocoder", type=str, default="bigvgan")
parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
@ -62,6 +63,8 @@ def main():
diffusion_sampler=args.diffusion_sampler, diffusion_sampler=args.diffusion_sampler,
cond_free=args.cond_free, cond_free=args.cond_free,
vocoder_type=args.vocoder,
) )
""" """
language=args.language, language=args.language,
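The new flag threads through to TTS.inference() as vocoder_type. A minimal invocation sketch, assuming the package's usual python -m entry point (the module name is an assumption; only --vocoder itself comes from this commit):

python -m tortoise_tts --vocoder hifigan ...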

View File

@ -526,6 +526,8 @@ class Config(BaseConfig):
sample_rate: int = 24_000 sample_rate: int = 24_000
audio_backend: str = "mel" audio_backend: str = "mel"
vocoder: str = "bigvgan" # "vocoder" | "bigvgan" | "hifigan"
@property @property
def model(self): def model(self):
for i, model in enumerate(self.models): for i, model in enumerate(self.models):

View File

@ -5,6 +5,7 @@ import soundfile
from torch import Tensor from torch import Tensor
from einops import rearrange from einops import rearrange
from pathlib import Path from pathlib import Path
from tqdm import tqdm
from .emb.mel import encode_from_files as encode_mel, trim, trim_random from .emb.mel import encode_from_files as encode_mel, trim, trim_random
from .utils import to_device from .utils import to_device
@ -96,6 +97,21 @@ class TTS():
# merge inputs # merge inputs
return encode_mel( paths, device=self.device ) return encode_mel( paths, device=self.device )
# taken from https://github.com/coqui-ai/TTS/blob/d21f15cc850788f9cdf93dac0321395138665287/TTS/tts/models/xtts.py#L666
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
"""Handle chunk formatting in streaming mode"""
wav_chunk = wav_gen[:-overlap_len]
if wav_gen_prev is not None:
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
if wav_overlap is not None:
crossfade_wav = wav_chunk[:overlap_len]
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
wav_chunk[:overlap_len] += crossfade_wav
wav_overlap = wav_gen[-overlap_len:]
wav_gen_prev = wav_gen
return wav_chunk, wav_gen_prev, wav_overlap
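# A toy sketch of the crossfade above, not part of this file: the previous chunk's tail
# fades out while the new chunk's head fades in, so the two linspace ramps sum to a
# constant envelope across the overlap.
import torch
overlap = 4
prev_tail = torch.full((overlap,), 1.0)  # stands in for wav_overlap
new_head = torch.full((overlap,), 0.5)   # start of the next generated chunk
blended = prev_tail * torch.linspace(1.0, 0.0, overlap) + new_head * torch.linspace(0.0, 1.0, overlap)
print(blended)  # tensor([1.0000, 0.8333, 0.6667, 0.5000])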
@torch.inference_mode() @torch.inference_mode()
def inference( def inference(
self, self,
@ -122,7 +138,9 @@ class TTS():
diffusion_sampler="ddim", diffusion_sampler="ddim",
cond_free=True, cond_free=True,
out_path=None vocoder_type="bigvgan",
out_path=None,
): ):
lines = text.split("\n") lines = text.split("\n")
@ -142,7 +160,7 @@ class TTS():
diffusion = engine.module diffusion = engine.module
elif "clvp" in name: elif "clvp" in name:
clvp = engine.module clvp = engine.module
elif "vocoder" in name: elif vocoder_type in name:
vocoder = engine.module vocoder = engine.module
if autoregressive is None: if autoregressive is None:
@ -152,7 +170,7 @@ class TTS():
if clvp is None: if clvp is None:
clvp = load_model("clvp", device=cfg.device) clvp = load_model("clvp", device=cfg.device)
if vocoder is None: if vocoder is None:
vocoder = load_model("vocoder", device=cfg.device) vocoder = load_model(vocoder_type, device=cfg.device)
autoregressive = autoregressive.to(cfg.device) autoregressive = autoregressive.to(cfg.device)
diffusion = diffusion.to(cfg.device) diffusion = diffusion.to(cfg.device)
@ -183,6 +201,88 @@ class TTS():
text_tokens = pad_sequence([ text ], batch_first = True) text_tokens = pad_sequence([ text ], batch_first = True)
text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32) text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32)
# streaming interface spits out the final hidden state, which HiFiGAN seems to be trained against
if vocoder_type == "hifigan":
waves = []
all_latents = []
all_codes = []
wav_gen_prev = None
wav_overlap = None
is_end = False
first_buffer = 60
stream_chunk_size = 40
overlap_wav_len = 1024
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload):
with ml.auto_unload(vocoder, enabled=cfg.inference.auto_unload):
inputs = autoregressive.compute_embeddings( autoregressive_latents, text_tokens )
gpt_generator = autoregressive.get_generator(
inputs=inputs,
top_k=top_k,
top_p=top_p,
temperature=ar_temp,
do_sample=True,
num_beams=max(1, beam_width),
num_return_sequences=1,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
output_hidden_states=True,
)
bar = tqdm( unit="it", total=500 )
while not is_end:
try:
codes, latent = next(gpt_generator)
all_latents += [latent]
all_codes += [codes]
except StopIteration:
is_end = True
if is_end or (stream_chunk_size > 0 and len(all_codes) >= max(stream_chunk_size, first_buffer)):
first_buffer = 0
all_codes = []
bar.update( stream_chunk_size )
latents = torch.cat(all_latents, dim=0)[None, :].to(cfg.device)
wav_gen = vocoder.inference(latents, autoregressive_latents)
wav_gen = wav_gen.squeeze()
wav_chunk = wav_gen[:-overlap_wav_len]
if wav_gen_prev is not None:
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_wav_len) : -overlap_wav_len]
if wav_overlap is not None:
crossfade_wav = wav_chunk[:overlap_wav_len]
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_wav_len).to(crossfade_wav.device)
wav_chunk[:overlap_wav_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_wav_len).to(wav_overlap.device)
wav_chunk[:overlap_wav_len] += crossfade_wav
wav_overlap = wav_gen[-overlap_wav_len:]
wav_gen_prev = wav_gen
# yielding would require a bunch of pain to work around this turning into an async function
"""
yield wav_chunk
"""
waves.append( wav_chunk.unsqueeze(0) )
bar.close()
wav = torch.concat(waves, dim=-1)
if out_path is not None:
torchaudio.save( out_path, wav.cpu(), sr )
wavs.append(wav)
continue
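# if this method were itself a generator, the commented-out yield above would stream
# chunks to the caller as they decode; a hypothetical consumer, with an illustrative
# signature that is not part of this commit:
#   for wav_chunk in tts.inference(text=text, vocoder_type="hifigan"):
#       play_or_buffer(wav_chunk)  # hypothetical sink for streamed audio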
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload): with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload):
# autoregressive pass # autoregressive pass
@ -190,9 +290,11 @@ class TTS():
autoregressive_latents, autoregressive_latents,
text_tokens, text_tokens,
do_sample=True, do_sample=True,
top_k=top_k,
top_p=top_p, top_p=top_p,
temperature=ar_temp, temperature=ar_temp,
num_return_sequences=candidates, num_return_sequences=candidates,
num_beams=max(1,beam_width),
length_penalty=length_penalty, length_penalty=length_penalty,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
max_generate_length=max_ar_steps, max_generate_length=max_ar_steps,

View File

@ -1,5 +1,4 @@
# https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models # All other code in this folder is licensed per the attributions at the top
# All code under this folder is licensed as Apache License 2.0 per the original repo
from functools import cache from functools import cache
@ -8,6 +7,8 @@ from .arch_utils import TorchMelSpectrogram, TacotronSTFT
from .unified_voice import UnifiedVoice from .unified_voice import UnifiedVoice
from .diffusion import DiffusionTTS from .diffusion import DiffusionTTS
from .vocoder import UnivNetGenerator from .vocoder import UnivNetGenerator
from .bigvgan import BigVGAN
from .hifigan import HifiganGenerator
from .clvp import CLVP from .clvp import CLVP
from .dvae import DiscreteVAE from .dvae import DiscreteVAE
from .random_latent_generator import RandomLatentConverter from .random_latent_generator import RandomLatentConverter
@ -15,6 +16,8 @@ from .random_latent_generator import RandomLatentConverter
import os import os
import torch import torch
from pathlib import Path from pathlib import Path
import requests
from tqdm import tqdm
DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models' DEFAULT_MODEL_PATH = Path(__file__).parent.parent.parent / 'data/models'
DEFAULT_MODEL_URLS = { DEFAULT_MODEL_URLS = {
@ -28,10 +31,18 @@ DEFAULT_MODEL_URLS = {
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth', 'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
'mel_norms.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/data/mel_norms.pth', 'mel_norms.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/data/mel_norms.pth',
# BigVGAN
'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth',
'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth',
'bigvgan_base_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.json',
'bigvgan_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.json',
# HiFiGAN
'hifigan.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
} }
import requests
from tqdm import tqdm
# kludge, probably better to use HF's model downloader function # kludge, probably better to use HF's model downloader function
# to-do: write to a temp file then copy so downloads can be interrupted # to-do: write to a temp file then copy so downloads can be interrupted
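# a sketch of that alternative: huggingface_hub downloads through a temp file and renames
# on completion, which would also cover the interrupted-download to-do (not wired in here;
# the repo/filename mirror the DEFAULT_MODEL_URLS entries above)
from huggingface_hub import hf_hub_download
local_path = hf_hub_download(repo_id="jbetker/tortoise-tts-v2", filename="data/mel_norms.pth")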
@ -75,6 +86,7 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
@cache @cache
def load_model(name, device="cuda", **kwargs): def load_model(name, device="cuda", **kwargs):
load_path = None load_path = None
config_path = None
state_dict_key = None state_dict_key = None
strict = True strict = True
@ -95,6 +107,31 @@ def load_model(name, device="cuda", **kwargs):
elif "clvp" in name: elif "clvp" in name:
model = CLVP(**kwargs) model = CLVP(**kwargs)
load_path = DEFAULT_MODEL_PATH / 'clvp2.pth' load_path = DEFAULT_MODEL_PATH / 'clvp2.pth'
elif "bigvgan" in name:
# download any JSONs (BigVGAN)
load_path = DEFAULT_MODEL_PATH / 'bigvgan_24khz_100band.pth'
config_path = load_path.with_suffix(".json")
if config_path.name in DEFAULT_MODEL_URLS:
if not config_path.exists():
download_model( config_path )
else:
config_path = None
model = BigVGAN(config=config_path, **kwargs)
state_dict_key = 'generator'
elif "hifigan" in name:
model = HifiganGenerator(
in_channels=1024,
out_channels = 1,
resblock_type = "1",
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
resblock_kernel_sizes = [3, 7, 11],
upsample_kernel_sizes = [16, 16, 4, 4],
upsample_initial_channel = 512,
upsample_factors = [8, 8, 2, 2],
cond_channels=1024
)
load_path = DEFAULT_MODEL_PATH / 'hifigan.pth'
elif "vocoder" in name: elif "vocoder" in name:
model = UnivNetGenerator(**kwargs) model = UnivNetGenerator(**kwargs)
load_path = DEFAULT_MODEL_PATH / 'vocoder.pth' load_path = DEFAULT_MODEL_PATH / 'vocoder.pth'
@ -126,6 +163,11 @@ def load_model(name, device="cuda", **kwargs):
model.eval() model.eval()
try:
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
except Exception as e:
print(f"{name}: {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
return model return model
def unload_model(): def unload_model():
@ -138,8 +180,6 @@ def get_model(config, training=True):
config.training = "autoregressive" in config.name config.training = "autoregressive" in config.name
model.config = config model.config = config
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
return model return model
def get_models(models, training=True): def get_models(models, training=True):

View File

@ -1,3 +1,5 @@
# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/arch_utils.py
import os import os
import functools import functools
import math import math

View File

@ -0,0 +1,772 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import json
import os
import torch.utils.data
from torch import nn, sin, pow
from torch.nn import Conv1d, ConvTranspose1d, Conv2d, Parameter
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from librosa.filters import mel as librosa_mel_fn
# filter.py
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adapted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)
# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
#For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
if cutoff < 0.:
raise ValueError("Minimum cutoff must not be negative.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1),
stride=self.stride, groups=C)
return out
# resample.py
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
def forward(self, x):
xx = self.lowpass(x)
return xx
# act.py
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
class Activation1d(nn.Module):
def __init__(self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
# activations.py
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = Snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class SnakeBeta(nn.Module):
'''
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = SnakeBeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
'''
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
self.beta = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.beta = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
# bigvgan.py
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
LRELU_SLOPE = 0.1
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
super(AMPBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'.")
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class AMPBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
super(AMPBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'.")
def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
class BigVGAN(nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self, config=None, data=None):
super(BigVGAN, self).__init__()
"""
with open(os.path.join(os.path.dirname(__file__), 'config.json'), 'r') as f:
data = f.read()
"""
if config and data is None:
with open(config, 'r') as f:
data = f.read()
jsonConfig = json.loads(data)
elif data is not None:
if isinstance(data, str):
jsonConfig = json.loads(data)
else:
jsonConfig = data
else:
raise Exception("no config specified")
global h
h = AttrDict(jsonConfig)
self.mel_channel = h.num_mels
self.noise_dim = h.n_fft
self.hop_length = h.hop_size
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# pre conv
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(nn.ModuleList([
weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
h.upsample_initial_channel // (2 ** (i + 1)),
k, u, padding=(k - u) // 2))
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'.")
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
# weight initialization
for i in range(len(self.ups)):
self.ups[i].apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x, c): # note: the noise tensor c passed by inference() is unused in this forward pass
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
for l_i in l:
remove_weight_norm(l_i)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
def inference(self, c, z=None):
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((c.shape[0], h.num_mels, 10), -11.5129).to(c.device)
mel = torch.cat((c, zero), dim=2)
if z is None:
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
audio = self.forward(mel, z)
audio = audio[:, :, :-(self.hop_length * 10)]
audio = audio.clamp(min=-1, max=1)
return audio
def eval(self, inference=False):
super(BigVGAN, self).eval()
# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()
class DiscriminatorP(nn.Module):
def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.d_mult = h.discriminator_channel_mult
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1),
padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1),
padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1),
padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
])
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(nn.Module):
def __init__(self, h):
super(MultiPeriodDiscriminator, self).__init__()
self.mpd_reshapes = h.mpd_reshapes
print("mpd_reshapes: {}".format(self.mpd_reshapes))
discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
self.discriminators = nn.ModuleList(discriminators)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(self, cfg, resolution):
super().__init__()
self.resolution = resolution
assert len(self.resolution) == 3, \
"MRD layer requires list with len=3, got {}".format(self.resolution)
self.lrelu_slope = LRELU_SLOPE
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
if hasattr(cfg, "mrd_use_spectral_norm"):
print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm))
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
self.d_mult = cfg.discriminator_channel_mult
if hasattr(cfg, "mrd_channel_mult"):
print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult))
self.d_mult = cfg.mrd_channel_mult
self.convs = nn.ModuleList([
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
])
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
def forward(self, x):
fmap = []
x = self.spectrogram(x)
x = x.unsqueeze(1)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, self.lrelu_slope)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
def spectrogram(self, x):
n_fft, hop_length, win_length = self.resolution
x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
x = x.squeeze(1)
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
x = torch.view_as_real(x) # [B, F, TT, 2]
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
return mag
class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False):
super().__init__()
self.resolutions = cfg.resolutions
assert len(self.resolutions) == 3, \
"MRD requires list of list with len=3, each element having a list with len=3. got {}". \
format(self.resolutions)
self.discriminators = nn.ModuleList(
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
# module-level caches for mel filterbanks and Hann windows, keyed by fmax/device
mel_basis = {}
hann_window = {}
def get_mel(x):
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
global mel_basis, hann_window
key = str(fmax) + '_' + str(y.device)
if key not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = y.squeeze(1)
# complex tensor as default, then use view_as_real for future pytorch compatibility
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
spec = torch.log(torch.clamp(spec, min=1e-5)) # dynamic range compression (spectral normalization)
return spec
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg ** 2)
loss += (r_loss + g_loss)
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
if __name__ == '__main__':
model = BigVGAN() # note: now raises "no config specified"; pass config=<path to a bigvgan *.json> or data=<dict>
c = torch.randn(3, 100, 10)
z = torch.randn(3, 64, 10)
print(c.shape)
y = model(c, z)
print(y.shape)
assert y.shape == torch.Size([3, 1, 2560])
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

View File

@ -1,3 +1,5 @@
# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/classifier.py
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -1,3 +1,5 @@
# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/clvp.py
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -1,3 +1,5 @@
# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/diffusion.py
import enum import enum
import math import math
import random import random

View File

@ -0,0 +1,305 @@
# Grabbed from https://git.ecker.tech/Jarod/tortoise-tts/src/branch/main
# Adapted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
LRELU_SLOPE = 0.1
def get_padding(k, d):
return int((k * d - d) / 2)
class ResBlock1(torch.nn.Module):
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
Network::
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|--------------------------------------------------------------------------------------------------|
Args:
channels (int): number of hidden channels for the convolutional layers.
kernel_size (int): size of the convolution filter in each layer.
dilations (list): list of dilation value for each conv layer in a block.
"""
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
]
)
def forward(self, x):
"""
Args:
x (Tensor): input tensor.
Returns:
Tensor: output tensor.
Shapes:
x: [B, C, T]
"""
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
Network::
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|---------------------------------------------------|
Args:
channels (int): number of hidden channels for the convolutional layers.
kernel_size (int): size of the convolution filter in each layer.
dilations (list): list of dilation value for each conv layer in a block.
"""
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super().__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class HifiganGenerator(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
resblock_type,
resblock_dilation_sizes,
resblock_kernel_sizes,
upsample_kernel_sizes,
upsample_initial_channel,
upsample_factors,
inference_padding=5,
cond_channels=0,
conv_pre_weight_norm=True,
conv_post_weight_norm=True,
conv_post_bias=True,
):
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
Network:
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
.. -> zI ---|
resblockN_kNx1 -> zN ---'
Args:
in_channels (int): number of input tensor channels.
out_channels (int): number of output tensor channels.
resblock_type (str): type of the `ResBlock`. '1' or '2'.
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
for each consecutive upsampling layer.
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
"""
super().__init__()
self.inference_padding = inference_padding
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_factors)
# initial upsampling layers
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
# MRF blocks
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre)
if not conv_post_weight_norm:
remove_weight_norm(self.conv_post)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
self.device = torch.device('mps')
def forward(self, x, g=None):
"""
Args:
x (Tensor): feature input tensor.
g (Tensor): global conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
o = self.conv_pre(x)
if hasattr(self, "cond_layer"):
o = o + self.cond_layer(g)
for i in range(self.num_upsamples):
o = F.leaky_relu(o, LRELU_SLOPE)
o = self.ups[i](o)
z_sum = None
for j in range(self.num_kernels):
if z_sum is None:
z_sum = self.resblocks[i * self.num_kernels + j](o)
else:
z_sum += self.resblocks[i * self.num_kernels + j](o)
o = z_sum / self.num_kernels
o = F.leaky_relu(o)
o = self.conv_post(o)
o = torch.tanh(o)
return o
@torch.no_grad()
def inference(self, c, g=None):
"""
Args:
x (Tensor): conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
# c = c.to(self.conv_pre.weight.device)
# c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
up_1 = torch.nn.functional.interpolate(
c.transpose(1,2),
scale_factor=[1024 / 256],
mode="linear",
)
up_2 = torch.nn.functional.interpolate(
up_1,
scale_factor=[24000 / 22050],
mode="linear",
)
g = g.unsqueeze(0)
return self.forward(up_2.to(self.device), g.transpose(1,2))
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
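A smoke-test sketch of driving this generator with the same constructor arguments load_model() passes for "hifigan" elsewhere in this commit. The tensor shapes are assumptions inferred from inference() above and from the streaming call in TTS.inference, and random weights produce noise rather than speech:

import torch
model = HifiganGenerator(
    in_channels=1024, out_channels=1, resblock_type="1",
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes=[3, 7, 11],
    upsample_kernel_sizes=[16, 16, 4, 4], upsample_initial_channel=512,
    upsample_factors=[8, 8, 2, 2], cond_channels=1024,
)
model.device = torch.device('cpu')  # pin the device inference() hardcodes to where the weights actually live
latents = torch.randn(1, 50, 1024)  # [B, T, C] GPT hidden states (assumed layout)
cond = torch.randn(1, 1024)         # conditioning latent; inference() reshapes it to [1, 1024, 1]
wav = model.inference(latents, cond)
print(wav.shape)                    # torch.Size([1, 1, 55552]): 50 frames -> ~217 after resampling -> x256 upsampling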

View File

@ -1,4 +1,5 @@
# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py # Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
from functools import partial from functools import partial
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -1,3 +1,6 @@
# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/random_latent_generator.py
import math import math
import torch import torch

File diff suppressed because it is too large

View File

@ -1,3 +1,5 @@
# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/transformer.py
from functools import partial from functools import partial
import torch import torch

View File

@ -1,3 +1,5 @@
# Adapted from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/unified_voice.py
import functools import functools
import torch import torch
@ -14,6 +16,8 @@ from transformers import LogitsWarper
from transformers import GPT2Config, GPT2Model from transformers import GPT2Config, GPT2Model
from tqdm import tqdm from tqdm import tqdm
from .stream_generator import NewGenerationMixin
AVAILABLE_ATTENTIONS = ["mem_efficient", "math"] AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
try: try:
@ -83,12 +87,14 @@ class ResBlock(nn.Module):
def forward(self, x): def forward(self, x):
return F.relu(self.net(x) + x) return F.relu(self.net(x) + x)
class GPT2InferenceModel(GPT2PreTrainedModel): class GPT2InferenceModel(GPT2PreTrainedModel, NewGenerationMixin):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True): def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True):
super().__init__(config) super(NewGenerationMixin, self).__init__()
super(GPT2PreTrainedModel, self).__init__(config)
self.transformer = gpt self.transformer = gpt
self.text_pos_embedding = text_pos_emb self.text_pos_embedding = text_pos_emb
self.embeddings = embeddings self.embeddings = embeddings
self.final_norm = norm
self.lm_head = nn.Sequential(norm, linear) self.lm_head = nn.Sequential(norm, linear)
self.kv_cache = kv_cache self.kv_cache = kv_cache
@ -129,14 +135,14 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
def store_mel_emb(self, mel_emb): def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb self.cached_mel_emb = mel_emb
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) token_type_ids = kwargs.get("token_type_ids", None)
if not self.kv_cache: if not self.kv_cache:
past = None past_key_values = None
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past: if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1) input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None: if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1) token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
@ -148,13 +154,13 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
# create position_ids on the fly for batch generation # create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1 position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1) position_ids.masked_fill_(attention_mask == 0, 1)
if past: if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1) position_ids = position_ids[:, -1].unsqueeze(-1)
else: else:
position_ids = None position_ids = None
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"past_key_values": past, "past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"), "use_cache": kwargs.get("use_cache"),
"position_ids": position_ids, "position_ids": position_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
@ -597,6 +603,24 @@ class UnifiedVoice(nn.Module):
return loss_text.mean(), loss_mel.mean(), mel_logits return loss_text.mean(), loss_mel.mean(), mel_logits
def compute_embeddings( self, cond_latents, text_inputs, kv_cache = True ):
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
conds = cond_latents.unsqueeze(1)
emb = torch.cat([conds, emb], dim=1)
if not hasattr(self, 'inference_model'):
# TODO: Decouple gpt_config from this inference model.
self.post_init_gpt2_config(kv_cache = kv_cache)
self.inference_model.store_mel_emb(emb)
embs = torch.full( ( emb.shape[0], emb.shape[1] + 1 ), fill_value=1, dtype=torch.long, device=text_inputs.device )
embs[:, -1] = self.start_mel_token
return embs
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs): max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
@ -635,6 +659,17 @@ class UnifiedVoice(nn.Module):
self.inference_model.bar.close() self.inference_model.bar.close()
return gen[:, trunc_index:] return gen[:, trunc_index:]
def get_generator(self, inputs, max_length=500, **hf_generate_kwargs):
return self.inference_model.generate(
inputs,
bos_token_id=self.start_mel_token,
pad_token_id=self.stop_mel_token,
eos_token_id=self.stop_mel_token,
max_length=max_length,
do_stream=True,
**hf_generate_kwargs,
)
if __name__ == '__main__': if __name__ == '__main__':
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4) gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)

View File

@ -1,3 +1,5 @@
# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/vocoder.py
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -1,3 +1,5 @@
# Copied from https://github.com/neonbjb/tortoise-tts/tree/98a891e66e7a1f11a830f31bd1ce06cc1f6a88af/tortoise/models/xtransformers.py
import math import math
from collections import namedtuple from collections import namedtuple
from functools import partial from functools import partial

View File

@ -95,6 +95,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"]) parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"]) parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"])
parser.add_argument("--cond-free", type=str, default=kwargs["cond-free"]) parser.add_argument("--cond-free", type=str, default=kwargs["cond-free"])
parser.add_argument("--vocoder", type=str, default=kwargs["vocoder"].lower())
""" """
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"]) parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
@ -126,6 +127,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
beam_width=args.beam_width, beam_width=args.beam_width,
diffusion_sampler=args.diffusion_sampler, diffusion_sampler=args.diffusion_sampler,
vocoder_type=args.vocoder,
) )
wav = wav.squeeze(0).cpu().numpy() wav = wav.squeeze(0).cpu().numpy()
@ -210,7 +212,7 @@ with ui:
with gr.Column(scale=1): with gr.Column(scale=1):
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS") layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
# layout["inference"]["stop"] = gr.Button(value="Stop") # layout["inference"]["stop"] = gr.Button(value="Stop")
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output") layout["inference"]["outputs"]["output"] = gr.Audio(label="Output", streaming=True)
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference") layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
with gr.Column(scale=7): with gr.Column(scale=7):
with gr.Row(): with gr.Row():
@ -221,6 +223,7 @@ with ui:
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.8, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)") layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.8, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.") layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.")
layout["inference"]["inputs"]["vocoder"] = gr.Radio( ["Vocoder", "BigVGAN", "HiFiGAN"], value="BigVGAN", label="Vocoder", type="value", info="Vocoder to use for generating the final waveform (HiFiGAN skips diffusion)." )
""" """
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.") layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")