applied the bitsandbytes wrapper to tortoise inference (not sure if it matters)
This commit is contained in:
parent 7cc0250a1a
commit 7b839a4263
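The pattern applied across this diff: instead of constructing torch.nn.Linear / torch.nn.Embedding directly, the models import the new shim module and take the layer classes from it, and the shim decides at import time whether to hand back the stock torch classes or their bitsandbytes 8-bit counterparts. A minimal sketch of that drop-in usage (illustration only, not part of this commit; assumes bitsandbytes is installed):

    import tortoise.utils.torch_intermediary as ml

    # ml.Linear / ml.Embedding resolve to the bitsandbytes classes when the
    # corresponding OVERRIDE_* flag in torch_intermediary is true, and fall
    # back to torch.nn.Linear / torch.nn.Embedding otherwise.
    proj = ml.Linear(512, 1024, bias=False)
    emb = ml.Embedding(256, 512)
    print(type(proj), type(emb))  # shows which backend was actually picked up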
@@ -221,6 +221,16 @@ class TextToSpeech:
         if device is None:
             device = get_device(verbose=True)

+        try:
+            import tortoise.utils.torch_intermediary as ml
+            if ml.OVERRIDE_ADAM:
+                print("Using BitsAndBytes ADAMW optimizations")
+            else:
+                print("NOT using BitsAndBytes ADAMW optimizations")
+        except Exception as e:
+            print(e)
+            pass
+
         self.input_sample_rate = input_sample_rate
         self.output_sample_rate = output_sample_rate
         self.minor_optimizations = minor_optimizations
@@ -11,6 +11,7 @@ from tortoise.utils.typical_sampling import TypicalLogitsWarper

 from tortoise.utils.device import get_device_count

+import tortoise.utils.torch_intermediary as ml

 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -221,7 +222,8 @@ class ConditioningEncoder(nn.Module):
 class LearnedPositionEmbeddings(nn.Module):
     def __init__(self, seq_len, model_dim, init=.02):
         super().__init__()
-        self.emb = nn.Embedding(seq_len, model_dim)
+        # ml.Embedding
+        self.emb = ml.Embedding(seq_len, model_dim)
         # Initializing this way is standard for GPT-2
         self.emb.weight.data.normal_(mean=0.0, std=init)

@@ -321,9 +323,11 @@ class UnifiedVoice(nn.Module):
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
-        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
+        # ml.Embedding
+        self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
         if use_mel_codes_as_input:
-            self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+            # ml.Embedding
+            self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
@@ -336,8 +340,10 @@ class UnifiedVoice(nn.Module):
             self.text_solo_embedding = 0

         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
-        self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
+        # nn.Linear
+        self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
+        # nn.Linear
+        self.mel_head = ml.Linear(model_dim, self.number_mel_codes)

         # Initialize the embeddings per the GPT-2 scheme
         embeddings = [self.text_embedding]
@@ -3,6 +3,7 @@ import torch.nn as nn

 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock

+import tortoise.utils.torch_intermediary as ml

 class ResBlock(nn.Module):
     def __init__(
@@ -124,7 +125,8 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
     def __init__(self, classes, distribute_zero_label=True, **kwargs):
         super().__init__()
         self.enc = AudioMiniEncoder(**kwargs)
-        self.head = nn.Linear(self.enc.dim, classes)
+        # nn.Linear
+        self.head = ml.Linear(self.enc.dim, classes)
         self.num_classes = classes
         self.distribute_zero_label = distribute_zero_label

@@ -7,6 +7,7 @@ from tortoise.models.arch_util import CheckpointedXTransformerEncoder
 from tortoise.models.transformer import Transformer
 from tortoise.models.xtransformers import Encoder

+import tortoise.utils.torch_intermediary as ml

 def exists(val):
     return val is not None
@@ -44,11 +45,15 @@ class CLVP(nn.Module):
             use_xformers=False,
     ):
         super().__init__()
-        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
-        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
+        # nn.Embedding
+        self.text_emb = ml.Embedding(num_text_tokens, dim_text)
+        # nn.Linear
+        self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)

-        self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
-        self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+        # nn.Embedding
+        self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech)
+        # nn.Linear
+        self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)

         if use_xformers:
             self.text_transformer = CheckpointedXTransformerEncoder(
@@ -93,8 +98,10 @@ class CLVP(nn.Module):
         self.wav_token_compression = wav_token_compression
         self.xformers = use_xformers
         if not use_xformers:
-            self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
-            self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
+            # nn.Embedding
+            self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
+            # nn.Embedding
+            self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)

     def forward(
             self,
@@ -6,6 +6,7 @@ from torch import einsum
 from tortoise.models.arch_util import AttentionBlock
 from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder

+import tortoise.utils.torch_intermediary as ml

 def exists(val):
     return val is not None
@@ -54,7 +55,8 @@ class CollapsingTransformer(nn.Module):
 class ConvFormatEmbedding(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
-        self.emb = nn.Embedding(*args, **kwargs)
+        # nn.Embedding
+        self.emb = ml.Embedding(*args, **kwargs)

     def forward(self, x):
         y = self.emb(x)
@@ -83,7 +85,8 @@ class CVVP(nn.Module):
             nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
         self.conditioning_transformer = CollapsingTransformer(
             model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
-        self.to_conditioning_latent = nn.Linear(
+        # nn.Linear
+        self.to_conditioning_latent = ml.Linear(
             latent_dim, latent_dim, bias=False)

         if mel_codes is None:
@@ -93,7 +96,8 @@ class CVVP(nn.Module):
             self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(
             model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
-        self.to_speech_latent = nn.Linear(
+        # nn.Linear
+        self.to_speech_latent = ml.Linear(
             latent_dim, latent_dim, bias=False)

     def get_grad_norm_parameter_groups(self):
@@ -10,6 +10,8 @@ from torch import autocast
 from tortoise.models.arch_util import normalization, AttentionBlock
 from tortoise.utils.device import get_device_name

+import tortoise.utils.torch_intermediary as ml
+
 def is_latent(t):
     return t.dtype == torch.float

@@ -87,7 +89,8 @@ class ResBlock(TimestepBlock):

         self.emb_layers = nn.Sequential(
             nn.SiLU(),
-            nn.Linear(
+            # nn.Linear
+            ml.Linear(
                 emb_channels,
                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
             ),
@@ -160,16 +163,19 @@ class DiffusionTts(nn.Module):

         self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
         self.time_embed = nn.Sequential(
-            nn.Linear(model_channels, model_channels),
+            # nn.Linear
+            ml.Linear(model_channels, model_channels),
             nn.SiLU(),
-            nn.Linear(model_channels, model_channels),
+            # nn.Linear
+            ml.Linear(model_channels, model_channels),
         )

         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
-        self.code_embedding = nn.Embedding(in_tokens, model_channels)
+        # nn.Embedding
+        self.code_embedding = ml.Embedding(in_tokens, model_channels)
         self.code_converter = nn.Sequential(
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+import tortoise.utils.torch_intermediary as ml

 def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
     if bias is not None:
@@ -41,7 +42,8 @@ class RandomLatentConverter(nn.Module):
     def __init__(self, channels):
         super().__init__()
         self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
-                                    nn.Linear(channels, channels))
+                                    # nn.Linear
+                                    ml.Linear(channels, channels))
         self.channels = channels

     def forward(self, ref):
@@ -6,6 +6,7 @@ from einops import rearrange
 from rotary_embedding_torch import RotaryEmbedding, broadcat
 from torch import nn

+import tortoise.utils.torch_intermediary as ml

 # helpers

@@ -120,10 +121,12 @@ class FeedForward(nn.Module):
     def __init__(self, dim, dropout = 0., mult = 4.):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(dim, dim * mult * 2),
+            # nn.Linear
+            ml.Linear(dim, dim * mult * 2),
             GEGLU(),
             nn.Dropout(dropout),
-            nn.Linear(dim * mult, dim)
+            # nn.Linear
+            ml.Linear(dim * mult, dim)
         )

     def forward(self, x):
@@ -142,9 +145,11 @@ class Attention(nn.Module):

         self.causal = causal

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        # nn.Linear
+        self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, dim),
+            # nn.Linear
+            ml.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )

@@ -8,6 +8,8 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from torch import nn, einsum

+import tortoise.utils.torch_intermediary as ml
+
 DEFAULT_DIM_HEAD = 64

 Intermediates = namedtuple('Intermediates', [
@@ -121,7 +123,8 @@ class AbsolutePositionalEmbedding(nn.Module):
     def __init__(self, dim, max_seq_len):
         super().__init__()
         self.scale = dim ** -0.5
-        self.emb = nn.Embedding(max_seq_len, dim)
+        # nn.Embedding
+        self.emb = ml.Embedding(max_seq_len, dim)

     def forward(self, x):
         n = torch.arange(x.shape[1], device=x.device)
@@ -150,7 +153,8 @@ class RelativePositionBias(nn.Module):
         self.causal = causal
         self.num_buckets = num_buckets
         self.max_distance = max_distance
-        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+        # nn.Embedding
+        self.relative_attention_bias = ml.Embedding(num_buckets, heads)

     @staticmethod
     def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
@@ -350,7 +354,8 @@ class RMSScaleShiftNorm(nn.Module):
         self.scale = dim ** -0.5
         self.eps = eps
         self.g = nn.Parameter(torch.ones(dim))
-        self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
+        # nn.Linear
+        self.scale_shift_process = ml.Linear(dim * 2, dim * 2)

     def forward(self, x, norm_scale_shift_inp):
         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
@@ -430,7 +435,8 @@ class GLU(nn.Module):
     def __init__(self, dim_in, dim_out, activation):
         super().__init__()
         self.act = activation
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        # nn.Linear
+        self.proj = ml.Linear(dim_in, dim_out * 2)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -455,7 +461,8 @@ class FeedForward(nn.Module):
         activation = ReluSquared() if relu_squared else nn.GELU()

         project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            # nn.Linear
+            ml.Linear(dim, inner_dim),
             activation
         ) if not glu else GLU(dim, inner_dim, activation)

@@ -463,7 +470,8 @@ class FeedForward(nn.Module):
             project_in,
             nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            # nn.Linear
+            ml.Linear(inner_dim, dim_out)
         )

         # init last linear layer to 0
@@ -516,16 +524,20 @@ class Attention(nn.Module):
             qk_dim = int(collab_compression * qk_dim)
             self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))

-        self.to_q = nn.Linear(dim, qk_dim, bias=False)
-        self.to_k = nn.Linear(dim, qk_dim, bias=False)
-        self.to_v = nn.Linear(dim, v_dim, bias=False)
+        # nn.Linear
+        self.to_q = ml.Linear(dim, qk_dim, bias=False)
+        # nn.Linear
+        self.to_k = ml.Linear(dim, qk_dim, bias=False)
+        # nn.Linear
+        self.to_v = ml.Linear(dim, v_dim, bias=False)

         self.dropout = nn.Dropout(dropout)

         # add GLU gating for aggregated values, from alphafold2
         self.to_v_gate = None
         if gate_values:
-            self.to_v_gate = nn.Linear(dim, v_dim)
+            # nn.Linear
+            self.to_v_gate = ml.Linear(dim, v_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
             nn.init.constant_(self.to_v_gate.bias, 1)

@@ -561,7 +573,8 @@ class Attention(nn.Module):

         # attention on attention
         self.attn_on_attn = on_attn
-        self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
+        # nn.Linear
+        self.to_out = nn.Sequential(ml.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, dim)

         self.rel_pos_bias = rel_pos_bias
         if rel_pos_bias:
@@ -1051,7 +1064,8 @@ class ViTransformerWrapper(nn.Module):
         self.patch_size = patch_size

         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, dim)
+        # nn.Linear
+        self.patch_to_embedding = ml.Linear(patch_dim, dim)
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)

@@ -1109,18 +1123,21 @@ class TransformerWrapper(nn.Module):
         self.max_mem_len = max_mem_len
         self.shift_mem_down = shift_mem_down

-        self.token_emb = nn.Embedding(num_tokens, emb_dim)
+        # nn.Embedding
+        self.token_emb = ml.Embedding(num_tokens, emb_dim)
         self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
             use_pos_emb and not attn_layers.has_pos_emb) else always(0)
         self.emb_dropout = nn.Dropout(emb_dropout)

-        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+        # nn.Linear
+        self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)

         self.init_()

-        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+        # nn.Linear
+        self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

         # memory tokens (like [cls]) from Memory Transformers paper
         num_memory_tokens = default(num_memory_tokens, 0)
@@ -1207,12 +1224,14 @@ class ContinuousTransformerWrapper(nn.Module):
             use_pos_emb and not attn_layers.has_pos_emb) else always(0)
         self.emb_dropout = nn.Dropout(emb_dropout)

-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        # nn.Linear
+        self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()

         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)

-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+        # nn.Linear
+        self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()

     def forward(
             self,
tortoise/utils/torch_intermediary.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+"""
+from bitsandbytes.nn import Linear8bitLt as Linear
+from bitsandbytes.nn import StableEmbedding as Embedding
+from bitsandbytes.optim.adam import Adam8bit as Adam
+from bitsandbytes.optim.adamw import AdamW8bit as AdamW
+"""
+"""
+from torch.nn import Linear
+from torch.nn import Embedding
+from torch.optim.adam import Adam
+from torch.optim.adamw import AdamW
+"""
+
+"""
+OVERRIDE_LINEAR = False
+OVERRIDE_EMBEDDING = False
+OVERRIDE_ADAM = False # True
+OVERRIDE_ADAMW = False # True
+"""
+
+import os
+
+USE_STABLE_EMBEDDING = False
+try:
+    import bitsandbytes as bnb
+    OVERRIDE_LINEAR = False
+    OVERRIDE_EMBEDDING = True
+    OVERRIDE_ADAM = True
+    OVERRIDE_ADAMW = True
+
+    USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1'
+    OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1'
+    OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1'
+    OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1'
+    OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1'
+except Exception as e:
+    OVERRIDE_LINEAR = False
+    OVERRIDE_EMBEDDING = False
+    OVERRIDE_ADAM = False
+    OVERRIDE_ADAMW = False
+
+if OVERRIDE_LINEAR:
+    from bitsandbytes.nn import Linear8bitLt as Linear
+else:
+    from torch.nn import Linear
+
+if OVERRIDE_EMBEDDING:
+    if USE_STABLE_EMBEDDING:
+        from bitsandbytes.nn import StableEmbedding as Embedding
+    else:
+        from bitsandbytes.nn.modules import Embedding as Embedding
+else:
+    from torch.nn import Embedding
+
+if OVERRIDE_ADAM:
+    from bitsandbytes.optim.adam import Adam8bit as Adam
+else:
+    from torch.optim.adam import Adam
+
+if OVERRIDE_ADAMW:
+    from bitsandbytes.optim.adamw import AdamW8bit as AdamW
+else:
+    from torch.optim.adamw import AdamW
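Note that the overrides above are resolved the first time this module is imported, so the BITSANDBYTES_* environment variables only take effect if they are set before anything imports tortoise. A quick diagnostic (illustration only, not part of this commit):

    import tortoise.utils.torch_intermediary as ml

    # The flags show which overrides were enabled; the classes show whether the
    # bitsandbytes or the torch implementations were selected.
    print(ml.OVERRIDE_LINEAR, ml.OVERRIDE_EMBEDDING, ml.OVERRIDE_ADAM, ml.OVERRIDE_ADAMW)
    print(ml.Linear, ml.Embedding, ml.Adam, ml.AdamW)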
@@ -7,6 +7,8 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTo
 from tortoise.utils.audio import load_audio
 from tortoise.utils.device import get_device

+import tortoise.utils.torch_intermediary as ml
+
 def max_alignment(s1, s2, skip_character='~', record=None):
     """
     A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is