applied the bitsandbytes wrapper to tortoise inference (not sure if it matters)

remotes/1710189933836426429/master
mrq 2023-02-28 01:42:10 +07:00
parent 7cc0250a1a
commit 7b839a4263
11 changed files with 167 additions and 41 deletions
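The change is the same pattern repeated across the model files: import the new tortoise.utils.torch_intermediary shim as ml and build layers through it instead of torch.nn, so the bitsandbytes 8-bit implementations can be swapped in at import time. A minimal sketch of that pattern (the TinyHead module below is illustrative only, not code from this repo):

import torch.nn as nn
import tortoise.utils.torch_intermediary as ml  # resolves to bitsandbytes or torch.nn when imported

class TinyHead(nn.Module):
    def __init__(self, num_tokens, dim):
        super().__init__()
        # was: nn.Embedding(...) / nn.Linear(...)
        self.emb = ml.Embedding(num_tokens, dim)
        self.head = ml.Linear(dim, num_tokens)

    def forward(self, x):
        return self.head(self.emb(x))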

tortoise/api.py
@@ -221,6 +221,16 @@ class TextToSpeech:
if device is None:
device = get_device(verbose=True)
try:
import tortoise.utils.torch_intermediary as ml
if ml.OVERRIDE_ADAM:
print("Using BitsAndBytes ADAMW optimizations")
else:
print("NOT using BitsAndBytes ADAMW optimizations")
except Exception as e:
print(e)
pass
self.input_sample_rate = input_sample_rate
self.output_sample_rate = output_sample_rate
self.minor_optimizations = minor_optimizations
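The inference path above only reports whether the ADAMW override is active; the optimizer classes the shim exposes matter on the training side. A hedged sketch of pulling an optimizer through the same shim (the model and hyperparameters below are placeholders, not taken from this commit):

import torch.nn as nn
import tortoise.utils.torch_intermediary as ml

model = nn.Linear(16, 16)  # stand-in for a real model
# AdamW8bit when bitsandbytes imported cleanly and OVERRIDE_ADAMW is true, torch.optim.AdamW otherwise
optimizer = ml.AdamW(model.parameters(), lr=1e-4)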

tortoise/models/autoregressive.py
@@ -11,6 +11,7 @@ from tortoise.utils.typical_sampling import TypicalLogitsWarper
from tortoise.utils.device import get_device_count
import tortoise.utils.torch_intermediary as ml
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -221,7 +222,8 @@ class ConditioningEncoder(nn.Module):
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02):
super().__init__()
self.emb = nn.Embedding(seq_len, model_dim)
# ml.Embedding
self.emb = ml.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
@@ -321,9 +323,11 @@ class UnifiedVoice(nn.Module):
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
# ml.Embedding
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
# ml.Embedding
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
@@ -336,8 +340,10 @@ class UnifiedVoice(nn.Module):
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
# nn.Linear
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
# nn.Linear
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding]

tortoise/models/classifier.py
@@ -3,6 +3,7 @@ import torch.nn as nn
from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
import tortoise.utils.torch_intermediary as ml
class ResBlock(nn.Module):
def __init__(
@@ -124,7 +125,8 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
def __init__(self, classes, distribute_zero_label=True, **kwargs):
super().__init__()
self.enc = AudioMiniEncoder(**kwargs)
self.head = nn.Linear(self.enc.dim, classes)
# nn.Linear
self.head = ml.Linear(self.enc.dim, classes)
self.num_classes = classes
self.distribute_zero_label = distribute_zero_label

tortoise/models/clvp.py
@@ -7,6 +7,7 @@ from tortoise.models.arch_util import CheckpointedXTransformerEncoder
from tortoise.models.transformer import Transformer
from tortoise.models.xtransformers import Encoder
import tortoise.utils.torch_intermediary as ml
def exists(val):
return val is not None
@@ -44,11 +45,15 @@ class CLVP(nn.Module):
use_xformers=False,
):
super().__init__()
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
# nn.Embedding
self.text_emb = ml.Embedding(num_text_tokens, dim_text)
# nn.Linear
self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
# nn.Embedding
self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech)
# nn.Linear
self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
if use_xformers:
self.text_transformer = CheckpointedXTransformerEncoder(
@@ -93,8 +98,10 @@ class CLVP(nn.Module):
self.wav_token_compression = wav_token_compression
self.xformers = use_xformers
if not use_xformers:
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
# nn.Embedding
self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
# nn.Embedding
self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
def forward(
self,

tortoise/models/cvvp.py
@@ -6,6 +6,7 @@ from torch import einsum
from tortoise.models.arch_util import AttentionBlock
from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
import tortoise.utils.torch_intermediary as ml
def exists(val):
return val is not None
@@ -54,7 +55,8 @@ class CollapsingTransformer(nn.Module):
class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.emb = nn.Embedding(*args, **kwargs)
# nn.Embedding
self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@@ -83,7 +85,8 @@ class CVVP(nn.Module):
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.conditioning_transformer = CollapsingTransformer(
model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
self.to_conditioning_latent = nn.Linear(
# nn.Linear
self.to_conditioning_latent = ml.Linear(
latent_dim, latent_dim, bias=False)
if mel_codes is None:
@@ -93,7 +96,8 @@ class CVVP(nn.Module):
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(
model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
self.to_speech_latent = nn.Linear(
# nn.Linear
self.to_speech_latent = ml.Linear(
latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self):

tortoise/models/diffusion_decoder.py
@@ -10,6 +10,8 @@ from torch import autocast
from tortoise.models.arch_util import normalization, AttentionBlock
from tortoise.utils.device import get_device_name
import tortoise.utils.torch_intermediary as ml
def is_latent(t):
return t.dtype == torch.float
@@ -87,7 +89,8 @@ class ResBlock(TimestepBlock):
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
# nn.Linear
ml.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
@@ -160,16 +163,19 @@ class DiffusionTts(nn.Module):
self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
self.time_embed = nn.Sequential(
nn.Linear(model_channels, model_channels),
# nn.Linear
ml.Linear(model_channels, model_channels),
nn.SiLU(),
nn.Linear(model_channels, model_channels),
# nn.Linear
ml.Linear(model_channels, model_channels),
)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
self.code_embedding = nn.Embedding(in_tokens, model_channels)
# nn.Embedding
self.code_embedding = ml.Embedding(in_tokens, model_channels)
self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),

tortoise/models/random_latent_generator.py
@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import tortoise.utils.torch_intermediary as ml
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
if bias is not None:
@@ -41,7 +42,8 @@ class RandomLatentConverter(nn.Module):
def __init__(self, channels):
super().__init__()
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
nn.Linear(channels, channels))
# nn.Linear
ml.Linear(channels, channels))
self.channels = channels
def forward(self, ref):

tortoise/models/transformer.py
@@ -6,6 +6,7 @@ from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, broadcat
from torch import nn
import tortoise.utils.torch_intermediary as ml
# helpers
@@ -120,10 +121,12 @@ class FeedForward(nn.Module):
def __init__(self, dim, dropout = 0., mult = 4.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
# nn.Linear
ml.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
# nn.Linear
ml.Linear(dim * mult, dim)
)
def forward(self, x):
@@ -142,9 +145,11 @@ class Attention(nn.Module):
self.causal = causal
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
# nn.Linear
self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
# nn.Linear
ml.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

tortoise/models/xtransformers.py
@@ -8,6 +8,8 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn, einsum
import tortoise.utils.torch_intermediary as ml
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [
@@ -121,7 +123,8 @@ class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.emb = nn.Embedding(max_seq_len, dim)
# nn.Embedding
self.emb = ml.Embedding(max_seq_len, dim)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
@@ -150,7 +153,8 @@ class RelativePositionBias(nn.Module):
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
# nn.Embedding
self.relative_attention_bias = ml.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
@@ -350,7 +354,8 @@ class RMSScaleShiftNorm(nn.Module):
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
# nn.Linear
self.scale_shift_process = ml.Linear(dim * 2, dim * 2)
def forward(self, x, norm_scale_shift_inp):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
@@ -430,7 +435,8 @@ class GLU(nn.Module):
def __init__(self, dim_in, dim_out, activation):
super().__init__()
self.act = activation
self.proj = nn.Linear(dim_in, dim_out * 2)
# nn.Linear
self.proj = ml.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@@ -455,7 +461,8 @@ class FeedForward(nn.Module):
activation = ReluSquared() if relu_squared else nn.GELU()
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
# nn.Linear
ml.Linear(dim, inner_dim),
activation
) if not glu else GLU(dim, inner_dim, activation)
@@ -463,7 +470,8 @@ class FeedForward(nn.Module):
project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
# nn.Linear
ml.Linear(inner_dim, dim_out)
)
# init last linear layer to 0
@@ -516,16 +524,20 @@ class Attention(nn.Module):
qk_dim = int(collab_compression * qk_dim)
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
self.to_q = nn.Linear(dim, qk_dim, bias=False)
self.to_k = nn.Linear(dim, qk_dim, bias=False)
self.to_v = nn.Linear(dim, v_dim, bias=False)
# nn.Linear
self.to_q = ml.Linear(dim, qk_dim, bias=False)
# nn.Linear
self.to_k = ml.Linear(dim, qk_dim, bias=False)
# nn.Linear
self.to_v = ml.Linear(dim, v_dim, bias=False)
self.dropout = nn.Dropout(dropout)
# add GLU gating for aggregated values, from alphafold2
self.to_v_gate = None
if gate_values:
self.to_v_gate = nn.Linear(dim, v_dim)
# nn.Linear
self.to_v_gate = ml.Linear(dim, v_dim)
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 1)
@@ -561,7 +573,8 @@ class Attention(nn.Module):
# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
# nn.Linear
self.to_out = nn.Sequential(ml.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, dim)
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
@@ -1051,7 +1064,8 @@ class ViTransformerWrapper(nn.Module):
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
# nn.Linear
self.patch_to_embedding = ml.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
@@ -1109,18 +1123,21 @@ class TransformerWrapper(nn.Module):
self.max_mem_len = max_mem_len
self.shift_mem_down = shift_mem_down
self.token_emb = nn.Embedding(num_tokens, emb_dim)
# nn.Embedding
self.token_emb = ml.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
# nn.Linear
self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# nn.Linear
self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
@@ -1207,12 +1224,14 @@ class ContinuousTransformerWrapper(nn.Module):
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
# nn.Linear
self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
# nn.Linear
self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
def forward(
self,

tortoise/utils/torch_intermediary.py (new file)
@@ -0,0 +1,63 @@
"""
from bitsandbytes.nn import Linear8bitLt as Linear
from bitsandbytes.nn import StableEmbedding as Embedding
from bitsandbytes.optim.adam import Adam8bit as Adam
from bitsandbytes.optim.adamw import AdamW8bit as AdamW
"""
"""
from torch.nn import Linear
from torch.nn import Embedding
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
"""
"""
OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = False
OVERRIDE_ADAM = False # True
OVERRIDE_ADAMW = False # True
"""
import os
USE_STABLE_EMBEDDING = False
try:
    import bitsandbytes as bnb
    OVERRIDE_LINEAR = False
    OVERRIDE_EMBEDDING = True
    OVERRIDE_ADAM = True
    OVERRIDE_ADAMW = True
    USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1'
    OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1'
    OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1'
    OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1'
    OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1'
except Exception as e:
    OVERRIDE_LINEAR = False
    OVERRIDE_EMBEDDING = False
    OVERRIDE_ADAM = False
    OVERRIDE_ADAMW = False

if OVERRIDE_LINEAR:
    from bitsandbytes.nn import Linear8bitLt as Linear
else:
    from torch.nn import Linear

if OVERRIDE_EMBEDDING:
    if USE_STABLE_EMBEDDING:
        from bitsandbytes.nn import StableEmbedding as Embedding
    else:
        from bitsandbytes.nn.modules import Embedding as Embedding
else:
    from torch.nn import Embedding

if OVERRIDE_ADAM:
    from bitsandbytes.optim.adam import Adam8bit as Adam
else:
    from torch.optim.adam import Adam

if OVERRIDE_ADAMW:
    from bitsandbytes.optim.adamw import AdamW8bit as AdamW
else:
    from torch.optim.adamw import AdamW
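All of the overrides above are resolved once, when the module is first imported, from the environment variables it reads; flipping them does not require touching the model code. A rough usage sketch (the flag values are just examples of non-default settings):

import os
# must be set before tortoise.utils.torch_intermediary is first imported
os.environ["BITSANDBYTES_OVERRIDE_LINEAR"] = "1"       # also swap Linear -> Linear8bitLt (off by default)
os.environ["BITSANDBYTES_USE_STABLE_EMBEDDING"] = "1"  # prefer StableEmbedding when the Embedding override is on

import tortoise.utils.torch_intermediary as ml
print(ml.Linear, ml.Embedding, ml.AdamW)  # shows which implementations were selected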

tortoise/utils/wav2vec_alignment.py
@@ -7,6 +7,8 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer
from tortoise.utils.audio import load_audio
from tortoise.utils.device import get_device
import tortoise.utils.torch_intermediary as ml
def max_alignment(s1, s2, skip_character='~', record=None):
"""
A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is