diff --git a/tortoise/api.py b/tortoise/api.py
index 115cea6..490be33 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -221,6 +221,16 @@ class TextToSpeech:
         if device is None:
             device = get_device(verbose=True)
 
+        try:
+            import tortoise.utils.torch_intermediary as ml
+            if ml.OVERRIDE_ADAM:
+                print("Using BitsAndBytes ADAMW optimizations")
+            else:
+                print("NOT using BitsAndBytes ADAMW optimizations")
+        except Exception as e:
+            print(e)
+            pass
+
         self.input_sample_rate = input_sample_rate
         self.output_sample_rate = output_sample_rate
         self.minor_optimizations = minor_optimizations
diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py
index 3f797c0..b07a3b3 100755
--- a/tortoise/models/autoregressive.py
+++ b/tortoise/models/autoregressive.py
@@ -11,6 +11,7 @@ from tortoise.utils.typical_sampling import TypicalLogitsWarper
 
 from tortoise.utils.device import get_device_count
 
+import tortoise.utils.torch_intermediary as ml
 
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -221,7 +222,8 @@ class ConditioningEncoder(nn.Module):
 
 class LearnedPositionEmbeddings(nn.Module):
     def __init__(self, seq_len, model_dim, init=.02):
         super().__init__()
-        self.emb = nn.Embedding(seq_len, model_dim)
+        # ml.Embedding
+        self.emb = ml.Embedding(seq_len, model_dim)
         # Initializing this way is standard for GPT-2
         self.emb.weight.data.normal_(mean=0.0, std=init)
@@ -321,9 +323,11 @@ class UnifiedVoice(nn.Module):
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
-        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
+        # ml.Embedding
+        self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
         if use_mel_codes_as_input:
-            self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+            # ml.Embedding
+            self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
@@ -336,8 +340,10 @@ class UnifiedVoice(nn.Module):
             self.text_solo_embedding = 0
 
         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
-        self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
+        # nn.Linear
+        self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
+        # nn.Linear
+        self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
 
         # Initialize the embeddings per the GPT-2 scheme
         embeddings = [self.text_embedding]
diff --git a/tortoise/models/classifier.py b/tortoise/models/classifier.py
index f92d99e..86e5911 100644
--- a/tortoise/models/classifier.py
+++ b/tortoise/models/classifier.py
@@ -3,6 +3,7 @@ import torch.nn as nn
 
 from tortoise.models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
 
+import tortoise.utils.torch_intermediary as ml
 
 class ResBlock(nn.Module):
     def __init__(
@@ -124,7 +125,8 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
     def __init__(self, classes, distribute_zero_label=True, **kwargs):
         super().__init__()
         self.enc = AudioMiniEncoder(**kwargs)
-        self.head = nn.Linear(self.enc.dim, classes)
+        # nn.Linear
+        self.head = ml.Linear(self.enc.dim, classes)
         self.num_classes = classes
         self.distribute_zero_label = distribute_zero_label
 
diff --git a/tortoise/models/clvp.py b/tortoise/models/clvp.py
index 00f5011..71a744c 100644
--- a/tortoise/models/clvp.py
+++ b/tortoise/models/clvp.py
@@ -7,6 +7,7 @@ from tortoise.models.arch_util import CheckpointedXTransformerEncoder
 from tortoise.models.transformer import Transformer
 from tortoise.models.xtransformers import Encoder
 
+import tortoise.utils.torch_intermediary as ml
 
 def exists(val):
     return val is not None
@@ -44,11 +45,15 @@ class CLVP(nn.Module):
             use_xformers=False,
     ):
         super().__init__()
-        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
-        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
+        # nn.Embedding
+        self.text_emb = ml.Embedding(num_text_tokens, dim_text)
+        # nn.Linear
+        self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
 
-        self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
-        self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+        # nn.Embedding
+        self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech)
+        # nn.Linear
+        self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
 
         if use_xformers:
             self.text_transformer = CheckpointedXTransformerEncoder(
@@ -93,8 +98,10 @@ class CLVP(nn.Module):
         self.wav_token_compression = wav_token_compression
         self.xformers = use_xformers
         if not use_xformers:
-            self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
-            self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
+            # nn.Embedding
+            self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
+            # nn.Embedding
+            self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
 
     def forward(
             self,
diff --git a/tortoise/models/cvvp.py b/tortoise/models/cvvp.py
index 544ca47..692ffba 100644
--- a/tortoise/models/cvvp.py
+++ b/tortoise/models/cvvp.py
@@ -6,6 +6,7 @@ from torch import einsum
 from tortoise.models.arch_util import AttentionBlock
 from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder
 
+import tortoise.utils.torch_intermediary as ml
 
 def exists(val):
     return val is not None
@@ -54,7 +55,8 @@ class CollapsingTransformer(nn.Module):
 class ConvFormatEmbedding(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
-        self.emb = nn.Embedding(*args, **kwargs)
+        # nn.Embedding
+        self.emb = ml.Embedding(*args, **kwargs)
 
     def forward(self, x):
         y = self.emb(x)
@@ -83,7 +85,8 @@ class CVVP(nn.Module):
                                       nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
         self.conditioning_transformer = CollapsingTransformer(
             model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
-        self.to_conditioning_latent = nn.Linear(
+        # nn.Linear
+        self.to_conditioning_latent = ml.Linear(
             latent_dim, latent_dim, bias=False)
 
         if mel_codes is None:
@@ -93,7 +96,8 @@ class CVVP(nn.Module):
             self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(
             model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
-        self.to_speech_latent = nn.Linear(
+        # nn.Linear
+        self.to_speech_latent = ml.Linear(
             latent_dim, latent_dim, bias=False)
 
     def get_grad_norm_parameter_groups(self):
diff --git a/tortoise/models/diffusion_decoder.py b/tortoise/models/diffusion_decoder.py
index b383914..9d8bcde 100755
--- a/tortoise/models/diffusion_decoder.py
+++ b/tortoise/models/diffusion_decoder.py
@@ -10,6 +10,8 @@ from torch import autocast
 
 from tortoise.models.arch_util import normalization, AttentionBlock
 from tortoise.utils.device import get_device_name
 
+import tortoise.utils.torch_intermediary as ml
+
 def is_latent(t):
     return t.dtype == torch.float
@@ -87,7 +89,8 @@ class ResBlock(TimestepBlock):
 
         self.emb_layers = nn.Sequential(
             nn.SiLU(),
-            nn.Linear(
+            # nn.Linear
+            ml.Linear(
                 emb_channels,
                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
             ),
@@ -160,16 +163,19 @@ class DiffusionTts(nn.Module):
         self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
 
         self.time_embed = nn.Sequential(
-            nn.Linear(model_channels, model_channels),
+            # nn.Linear
+            ml.Linear(model_channels, model_channels),
             nn.SiLU(),
-            nn.Linear(model_channels, model_channels),
+            # nn.Linear
+            ml.Linear(model_channels, model_channels),
         )
 
         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
-        self.code_embedding = nn.Embedding(in_tokens, model_channels)
+        # nn.Embedding
+        self.code_embedding = ml.Embedding(in_tokens, model_channels)
         self.code_converter = nn.Sequential(
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
diff --git a/tortoise/models/random_latent_generator.py b/tortoise/models/random_latent_generator.py
index e90ef21..fab2fec 100644
--- a/tortoise/models/random_latent_generator.py
+++ b/tortoise/models/random_latent_generator.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+import tortoise.utils.torch_intermediary as ml
 
 def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
     if bias is not None:
@@ -41,7 +42,8 @@ class RandomLatentConverter(nn.Module):
     def __init__(self, channels):
         super().__init__()
         self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
-                                    nn.Linear(channels, channels))
+                                    # nn.Linear
+                                    ml.Linear(channels, channels))
         self.channels = channels
 
     def forward(self, ref):
diff --git a/tortoise/models/transformer.py b/tortoise/models/transformer.py
index 707e9eb..ace644b 100644
--- a/tortoise/models/transformer.py
+++ b/tortoise/models/transformer.py
@@ -6,6 +6,7 @@ from einops import rearrange
 from rotary_embedding_torch import RotaryEmbedding, broadcat
 from torch import nn
 
+import tortoise.utils.torch_intermediary as ml
 
 # helpers
 
@@ -120,10 +121,12 @@ class FeedForward(nn.Module):
     def __init__(self, dim, dropout = 0., mult = 4.):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(dim, dim * mult * 2),
+            # nn.Linear
+            ml.Linear(dim, dim * mult * 2),
             GEGLU(),
             nn.Dropout(dropout),
-            nn.Linear(dim * mult, dim)
+            # nn.Linear
+            ml.Linear(dim * mult, dim)
         )
 
     def forward(self, x):
@@ -142,9 +145,11 @@ class Attention(nn.Module):
 
         self.causal = causal
 
-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        # nn.Linear
+        self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, dim),
+            # nn.Linear
+            ml.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )
 
diff --git a/tortoise/models/xtransformers.py b/tortoise/models/xtransformers.py
index 8be2df4..ade3e90 100644
--- a/tortoise/models/xtransformers.py
+++ b/tortoise/models/xtransformers.py
@@ -8,6 +8,8 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from torch import nn, einsum
 
+import tortoise.utils.torch_intermediary as ml
+
 DEFAULT_DIM_HEAD = 64
 
 Intermediates = namedtuple('Intermediates', [
@@ -121,7 +123,8 @@ class AbsolutePositionalEmbedding(nn.Module):
     def __init__(self, dim, max_seq_len):
         super().__init__()
         self.scale = dim ** -0.5
-        self.emb = nn.Embedding(max_seq_len, dim)
+        # nn.Embedding
+        self.emb = ml.Embedding(max_seq_len, dim)
 
     def forward(self, x):
         n = torch.arange(x.shape[1], device=x.device)
@@ -150,7 +153,8 @@ class RelativePositionBias(nn.Module):
         self.causal = causal
         self.num_buckets = num_buckets
         self.max_distance = max_distance
-        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+        # nn.Embedding
+        self.relative_attention_bias = ml.Embedding(num_buckets, heads)
 
     @staticmethod
     def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
@@ -350,7 +354,8 @@ class RMSScaleShiftNorm(nn.Module):
         self.scale = dim ** -0.5
         self.eps = eps
         self.g = nn.Parameter(torch.ones(dim))
-        self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
+        # nn.Linear
+        self.scale_shift_process = ml.Linear(dim * 2, dim * 2)
 
     def forward(self, x, norm_scale_shift_inp):
         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
@@ -430,7 +435,8 @@ class GLU(nn.Module):
     def __init__(self, dim_in, dim_out, activation):
         super().__init__()
         self.act = activation
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        # nn.Linear
+        self.proj = ml.Linear(dim_in, dim_out * 2)
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -455,7 +461,8 @@ class FeedForward(nn.Module):
         activation = ReluSquared() if relu_squared else nn.GELU()
 
         project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            # nn.Linear
+            ml.Linear(dim, inner_dim),
             activation
         ) if not glu else GLU(dim, inner_dim, activation)
 
@@ -463,7 +470,8 @@ class FeedForward(nn.Module):
             project_in,
             nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            # nn.Linear
+            ml.Linear(inner_dim, dim_out)
         )
 
         # init last linear layer to 0
@@ -516,16 +524,20 @@ class Attention(nn.Module):
             qk_dim = int(collab_compression * qk_dim)
             self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
 
-        self.to_q = nn.Linear(dim, qk_dim, bias=False)
-        self.to_k = nn.Linear(dim, qk_dim, bias=False)
-        self.to_v = nn.Linear(dim, v_dim, bias=False)
+        # nn.Linear
+        self.to_q = ml.Linear(dim, qk_dim, bias=False)
+        # nn.Linear
+        self.to_k = ml.Linear(dim, qk_dim, bias=False)
+        # nn.Linear
+        self.to_v = ml.Linear(dim, v_dim, bias=False)
 
         self.dropout = nn.Dropout(dropout)
 
         # add GLU gating for aggregated values, from alphafold2
        self.to_v_gate = None
         if gate_values:
-            self.to_v_gate = nn.Linear(dim, v_dim)
+            # nn.Linear
+            self.to_v_gate = ml.Linear(dim, v_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
             nn.init.constant_(self.to_v_gate.bias, 1)
 
@@ -561,7 +573,8 @@ class Attention(nn.Module):
 
         # attention on attention
         self.attn_on_attn = on_attn
-        self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
+        # nn.Linear
+        self.to_out = nn.Sequential(ml.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, dim)
 
         self.rel_pos_bias = rel_pos_bias
         if rel_pos_bias:
@@ -1051,7 +1064,8 @@ class ViTransformerWrapper(nn.Module):
         self.patch_size = patch_size
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, dim)
+        # nn.Linear
+        self.patch_to_embedding = ml.Linear(patch_dim, dim)
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)
 
@@ -1109,18 +1123,21 @@ class TransformerWrapper(nn.Module):
         self.max_mem_len = max_mem_len
         self.shift_mem_down = shift_mem_down
 
-        self.token_emb = nn.Embedding(num_tokens, emb_dim)
+        # nn.Embedding
+        self.token_emb = ml.Embedding(num_tokens, emb_dim)
         self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                 use_pos_emb and not attn_layers.has_pos_emb) else always(0)
         self.emb_dropout = nn.Dropout(emb_dropout)
 
-        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+        # nn.Linear
+        self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)
 
         self.init_()
 
-        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+        # nn.Linear
+        self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
 
         # memory tokens (like [cls]) from Memory Transformers paper
         num_memory_tokens = default(num_memory_tokens, 0)
@@ -1207,12 +1224,14 @@ class ContinuousTransformerWrapper(nn.Module):
                 use_pos_emb and not attn_layers.has_pos_emb) else always(0)
         self.emb_dropout = nn.Dropout(emb_dropout)
 
-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        # nn.Linear
+        self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
 
         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)
 
-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+        # nn.Linear
+        self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
 
     def forward(
             self,
diff --git a/tortoise/utils/torch_intermediary.py b/tortoise/utils/torch_intermediary.py
new file mode 100644
index 0000000..fa509dc
--- /dev/null
+++ b/tortoise/utils/torch_intermediary.py
@@ -0,0 +1,63 @@
+"""
+from bitsandbytes.nn import Linear8bitLt as Linear
+from bitsandbytes.nn import StableEmbedding as Embedding
+from bitsandbytes.optim.adam import Adam8bit as Adam
+from bitsandbytes.optim.adamw import AdamW8bit as AdamW
+"""
+"""
+from torch.nn import Linear
+from torch.nn import Embedding
+from torch.optim.adam import Adam
+from torch.optim.adamw import AdamW
+"""
+
+"""
+OVERRIDE_LINEAR = False
+OVERRIDE_EMBEDDING = False
+OVERRIDE_ADAM = False # True
+OVERRIDE_ADAMW = False # True
+"""
+
+import os
+
+USE_STABLE_EMBEDDING = False
+try:
+    import bitsandbytes as bnb
+    OVERRIDE_LINEAR = False
+    OVERRIDE_EMBEDDING = True
+    OVERRIDE_ADAM = True
+    OVERRIDE_ADAMW = True
+
+    USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1'
+    OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1'
+    OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1'
+    OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1'
+    OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1'
+except Exception as e:
+    OVERRIDE_LINEAR = False
+    OVERRIDE_EMBEDDING = False
+    OVERRIDE_ADAM = False
+    OVERRIDE_ADAMW = False
+
+if OVERRIDE_LINEAR:
+    from bitsandbytes.nn import Linear8bitLt as Linear
+else:
+    from torch.nn import Linear
+
+if OVERRIDE_EMBEDDING:
+    if USE_STABLE_EMBEDDING:
+        from bitsandbytes.nn import StableEmbedding as Embedding
+    else:
+        from bitsandbytes.nn.modules import Embedding as Embedding
+else:
+    from torch.nn import Embedding
+
+if OVERRIDE_ADAM:
+    from bitsandbytes.optim.adam import Adam8bit as Adam
+else:
+    from torch.optim.adam import Adam
+
+if OVERRIDE_ADAMW:
+    from bitsandbytes.optim.adamw import AdamW8bit as AdamW
+else:
+    from torch.optim.adamw import AdamW
\ No newline at end of file
diff --git a/tortoise/utils/wav2vec_alignment.py b/tortoise/utils/wav2vec_alignment.py
index f11835f..35c9bb5 100755
--- a/tortoise/utils/wav2vec_alignment.py
+++ b/tortoise/utils/wav2vec_alignment.py
@@ -7,6 +7,8 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer
 from tortoise.utils.audio import load_audio
 from tortoise.utils.device import get_device
 
+import tortoise.utils.torch_intermediary as ml
+
 def max_alignment(s1, s2, skip_character='~', record=None):
     """
     A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is