diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index 48041acd..4c9caeb0 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -9,7 +9,7 @@ import torch.nn.utils.spectral_norm as SpectralNorm from math import sqrt from utils.util import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml def exists(val): @@ -74,7 +74,7 @@ def initialize_weights(net_l, scale=1): m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() - elif isinstance(m, bnb.nn.Linear8bitLt): + elif isinstance(m, ml.Linear): init.kaiming_normal_(m.weight, a=0, mode='fan_in') m.weight.data *= scale if m.bias is not None: @@ -109,7 +109,7 @@ def default_init_weights(module, scale=1): if isinstance(m, nn.Conv2d): kaiming_init(m, a=0, mode='fan_in', bias=0) m.weight.data *= scale - elif isinstance(m, bnb.nn.Linear8bitLt): + elif isinstance(m, ml.Linear): kaiming_init(m, a=0, mode='fan_in', bias=0) m.weight.data *= scale @@ -142,7 +142,7 @@ def linear(*args, **kwargs): """ Create a linear module. """ - return bnb.nn.Linear8bitLt(*args, **kwargs) + return ml.Linear(*args, **kwargs) def avg_pool_nd(dims, *args, **kwargs): diff --git a/codes/models/audio/asr/w2v_wrapper.py b/codes/models/audio/asr/w2v_wrapper.py index 2f1ccdde..f25c562e 100644 --- a/codes/models/audio/asr/w2v_wrapper.py +++ b/codes/models/audio/asr/w2v_wrapper.py @@ -9,7 +9,7 @@ from data.audio.unsupervised_audio_dataset import load_audio from models.audio.tts.tacotron2.text import sequence_to_text from trainer.networks import register_model from utils.util import opt_get -import bitsandbytes as bnb +import torch_intermediary as ml def only_letters(string): @@ -52,7 +52,7 @@ class Wav2VecWrapper(nn.Module): self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model) # Perform some surgery to get the model we actually want. self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled - self.w2v.lm_head = bnb.nn.Linear8bitLt(self.w2v.config.hidden_size, vocab_size) + self.w2v.lm_head = ml.Linear(self.w2v.config.hidden_size, vocab_size) self.w2v.config.vocab_size = vocab_size self.w2v.config.pad_token_id = 0 self.w2v.config.ctc_loss_reduction = 'sum' diff --git a/codes/models/audio/audio_resnet.py b/codes/models/audio/audio_resnet.py index 9c4d63df..59e8b2fa 100644 --- a/codes/models/audio/audio_resnet.py +++ b/codes/models/audio/audio_resnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from trainer.networks import register_model from utils.util import opt_get from typing import Type, Any, Callable, Union, List, Optional -import bitsandbytes as bnb +import torch_intermediary as ml __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', @@ -173,7 +173,7 @@ class ResNet(nn.Module): self.layer4 = self._make_layer(block, 512, layers[3], stride=4, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool1d(1) - self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes) + self.fc = ml.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv1d): diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index bac82b8f..25a00a81 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -15,14 +15,14 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled from models.arch_util import ResBlock from trainer.networks import register_model from utils.util import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml class Mel2Vec2FeatureProjection(nn.Module): def __init__(self, inner_dim, dropout): super().__init__() self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5) - self.projection = bnb.nn.Linear8bitLt(inner_dim, inner_dim) + self.projection = ml.Linear(inner_dim, inner_dim) self.dropout = nn.Dropout(dropout) def forward(self, hidden_states): @@ -59,10 +59,10 @@ class Wav2Vec2Attention(nn.Module): self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) - self.v_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) - self.q_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) - self.out_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) + self.k_proj = ml.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = ml.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = ml.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = ml.Linear(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -183,10 +183,10 @@ class Wav2Vec2FeedForward(nn.Module): super().__init__() self.intermediate_dropout = nn.Dropout(dropout) - self.intermediate_dense = bnb.nn.Linear8bitLt(hidden_size, intermediate_size) + self.intermediate_dense = ml.Linear(hidden_size, intermediate_size) self.intermediate_act_fn = F.gelu - self.output_dense = bnb.nn.Linear8bitLt(intermediate_size, hidden_size) + self.output_dense = ml.Linear(intermediate_size, hidden_size) self.output_dropout = nn.Dropout(dropout) def forward(self, hidden_states): @@ -430,7 +430,7 @@ class Mel2Vec(nn.Module): k = math.sqrt(1 / module.projection.in_features) nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) - elif isinstance(module, bnb.nn.Linear8bitLt): + elif isinstance(module, ml.Linear): if self.disable_custom_linear_init: return module.weight.data.normal_(mean=0.0, std=self.linear_init_scale) @@ -511,7 +511,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): self.codevectors = nn.Parameter( torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) ) - self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) + self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars) # can be decayed for training self.temperature = 2 @@ -607,8 +607,8 @@ class ContrastiveTrainingWrapper(nn.Module): self.inp_length_factor = inp_length_multiplier # make sure that project_hid & project_q are initialized like normal linear layers - self.project_hid = bnb.nn.Linear8bitLt(inner_dim, self.quantizer.codevector_dim) - self.project_q = bnb.nn.Linear8bitLt(self.quantizer.codevector_dim, self.quantizer.codevector_dim) + self.project_hid = ml.Linear(inner_dim, self.quantizer.codevector_dim) + self.project_q = ml.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim) self.reconstruction = do_reconstruction_loss if do_reconstruction_loss: diff --git a/codes/models/audio/music/cheater_gen_ar.py b/codes/models/audio/music/cheater_gen_ar.py index 9b1c1d9d..00222217 100644 --- a/codes/models/audio/music/cheater_gen_ar.py +++ b/codes/models/audio/music/cheater_gen_ar.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F from torch import nn from transformers import GPT2Config, GPT2Model -import bitsandbytes as bnb +import torch_intermediary as ml from models.arch_util import AttentionBlock, ResBlock from models.audio.tts.lucidrains_dvae import DiscreteVAE @@ -57,8 +57,8 @@ class ConditioningAR(nn.Module): del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = bnb.nn.StableEmbedding(num_vectors, dim) - self.head = bnb.nn.Linear8bitLt(dim, num_vectors) + self.embeddings = ml.Embedding(num_vectors, dim) + self.head = ml.Linear(dim, num_vectors) def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False): unused_params = [] diff --git a/codes/models/audio/music/diffwave.py b/codes/models/audio/music/diffwave.py index cf25bc19..71cf765c 100644 --- a/codes/models/audio/music/diffwave.py +++ b/codes/models/audio/music/diffwave.py @@ -17,7 +17,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from math import sqrt @@ -25,7 +25,7 @@ from torch.utils.checkpoint import checkpoint from trainer.networks import register_model -Linear = bnb.nn.Linear8bitLt +Linear = ml.Linear ConvTranspose2d = nn.ConvTranspose2d diff --git a/codes/models/audio/music/flat_diffusion.py b/codes/models/audio/music/flat_diffusion.py index cbc52a7d..9b8d897b 100644 --- a/codes/models/audio/music/flat_diffusion.py +++ b/codes/models/audio/music/flat_diffusion.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast -import bitsandbytes as bnb +import torch_intermediary as ml from models.arch_util import ResBlock from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear @@ -24,7 +24,7 @@ class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() # nn.Embedding - self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) + self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -161,7 +161,7 @@ class FlatDiffusion(nn.Module): # transformer network. if in_groups is None: # nn.Embedding - self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels) + self.embeddings = ml.Embedding(token_count, model_channels) else: self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) self.latent_conditioner = nn.Sequential( diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py index b2281662..03322df9 100644 --- a/codes/models/audio/music/gpt_music.py +++ b/codes/models/audio/music/gpt_music.py @@ -2,7 +2,7 @@ import torch from torch import nn import torch.nn.functional as F from transformers import GPT2Config, GPT2Model -import bitsandbytes as bnb +import torch_intermediary as ml from models.arch_util import AttentionBlock, ResBlock from models.audio.music.music_quantizer import MusicQuantizer @@ -138,8 +138,8 @@ class GptMusicLower(nn.Module): del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) - self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)]) + self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) + self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) def forward(self, mel, conditioning, return_latent=False): unused_params = [] @@ -241,8 +241,8 @@ class GptMusicUpper(nn.Module): del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)]) - self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_upper_vectors) for _ in range(num_upper_groups)]) + self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)]) + self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)]) def forward(self, mel, conditioning, return_latent=False): diff --git a/codes/models/audio/music/gpt_music2.py b/codes/models/audio/music/gpt_music2.py index 719d6317..d508eb70 100644 --- a/codes/models/audio/music/gpt_music2.py +++ b/codes/models/audio/music/gpt_music2.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F from torch import nn from transformers import GPT2Config, GPT2Model -import bitsandbytes as bnb +import torch_intermediary as ml from models.arch_util import AttentionBlock, ResBlock from models.audio.tts.lucidrains_dvae import DiscreteVAE @@ -75,8 +75,8 @@ class GptMusicLower(nn.Module): del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) - self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)]) + self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) + self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) def forward(self, mel, return_latent=False): unused_params = [] diff --git a/codes/models/audio/music/instrument_quantizer.py b/codes/models/audio/music/instrument_quantizer.py index 0d3a9c98..42b5706a 100644 --- a/codes/models/audio/music/instrument_quantizer.py +++ b/codes/models/audio/music/instrument_quantizer.py @@ -3,7 +3,7 @@ import functools import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from models.diffusion.nn import timestep_embedding from models.lucidrains.vq import VectorQuantize @@ -22,8 +22,8 @@ class SelfClassifyingHead(nn.Module): use_rmsnorm=True, ff_glu=True, do_checkpointing=False) self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2, sample_codebook_temp=init_temperature) - self.to_output = bnb.nn.Linear8bitLt(dim, out_dim) - self.to_decoder = bnb.nn.Linear8bitLt(out_dim, dim) + self.to_output = ml.Linear(dim, out_dim) + self.to_decoder = ml.Linear(out_dim, dim) def do_ar_step(self, x, used_codes): h = self.dec(x) @@ -91,7 +91,7 @@ class InstrumentQuantizer(nn.Module): """ super().__init__() self.op_dim = op_dim - self.proj = bnb.nn.Linear8bitLt(op_dim, dim) + self.proj = ml.Linear(op_dim, dim) self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)]) self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp) self.min_gumbel_temperature = min_temp diff --git a/codes/models/audio/music/mel2vec_codes_gpt.py b/codes/models/audio/music/mel2vec_codes_gpt.py index 74c3df7f..62c60768 100644 --- a/codes/models/audio/music/mel2vec_codes_gpt.py +++ b/codes/models/audio/music/mel2vec_codes_gpt.py @@ -1,7 +1,7 @@ import torch from torch import nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from transformers import GPT2Config, GPT2Model from trainer.networks import register_model @@ -19,8 +19,8 @@ class Mel2VecCodesGpt(nn.Module): self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_vectors, dim//num_groups) for _ in range(num_groups)]) - self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_vectors) for _ in range(num_groups)]) + self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)]) + self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)]) def forward(self, codes): assert codes.shape[-1] == self.num_groups diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py index 75bf2c2d..3d3b461b 100644 --- a/codes/models/audio/music/music_quantizer.py +++ b/codes/models/audio/music/music_quantizer.py @@ -3,7 +3,7 @@ import functools import torch from torch import nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from models.arch_util import zero_module from models.vqvae.vqvae import Quantize @@ -76,7 +76,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): self.codevectors = nn.Parameter( torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) ) - self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) + self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars) # can be decayed for training self.temperature = 2 diff --git a/codes/models/audio/music/music_quantizer2.py b/codes/models/audio/music/music_quantizer2.py index 585cf67c..d7df3658 100644 --- a/codes/models/audio/music/music_quantizer2.py +++ b/codes/models/audio/music/music_quantizer2.py @@ -3,7 +3,7 @@ import functools import torch from torch import nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from models.arch_util import zero_module from models.vqvae.vqvae import Quantize @@ -88,7 +88,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): self.codevectors = nn.Parameter( torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) ) - self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) + self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars) # can be decayed for training self.temperature = 2 diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index 161c1851..998906df 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch.nn.functional as F import torchaudio import torchvision -import bitsandbytes as bnb +import torch_intermediary as ml from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import TimestepBlock @@ -56,12 +56,12 @@ class ConcatAttentionBlock(TimestepBlock): self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) if cond_projection: self.tdim = trunk_dim+cond_dim_hidden - self.cond_project = bnb.nn.Linear8bitLt(cond_dim_in, cond_dim_hidden) + self.cond_project = ml.Linear(cond_dim_in, cond_dim_hidden) else: self.tdim = trunk_dim self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv) self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv) - self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False) + self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False) self.out.weight.data.zero_() def forward(self, x, cond, timestep_emb, rotary_emb): @@ -89,7 +89,7 @@ class ConditioningEncoder(nn.Module): self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1) self.time_proj = time_proj if time_proj: - self.time_proj = bnb.nn.Linear8bitLt(time_embed_dim, embedding_dim) + self.time_proj = ml.Linear(time_embed_dim, embedding_dim) self.attn = Encoder( dim=embedding_dim, depth=attn_blocks, diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index f3c60b53..ff3ccc90 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -4,7 +4,7 @@ from time import time import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from models.arch_util import ResBlock from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower @@ -29,7 +29,7 @@ class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() # nn.Embedding - self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) + self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -70,7 +70,7 @@ class ConcatAttentionBlock(TimestepBlock): self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout) self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout) - self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False) + self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False) self.out.weight.data.zero_() def forward(self, x, timestep_emb, rotary_emb): @@ -131,7 +131,7 @@ class TransformerDiffusion(nn.Module): ) prenet_heads = prenet_channels//64 - self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, prenet_channels) + self.input_converter = ml.Linear(input_vec_dim, prenet_channels) self.code_converter = Encoder( dim=prenet_channels, depth=prenet_layers, @@ -147,7 +147,7 @@ class TransformerDiffusion(nn.Module): self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels) + self.intg = ml.Linear(prenet_channels*2, model_channels) self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 2cecb011..c946af4e 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -5,7 +5,7 @@ from random import randrange import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \ RelativeQKBias @@ -71,13 +71,13 @@ class ConditioningEncoder(nn.Module): attn = [] self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2) # nn.Embedding - self.resolution_embedding = bnb.nn.StableEmbedding(num_resolutions, hidden_dim) + self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim) self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start. for a in range(attn_blocks): attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing)) attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing)) self.attn = nn.Sequential(*attn) - self.out = bnb.nn.Linear8bitLt(hidden_dim, out_dim, bias=False) + self.out = ml.Linear(hidden_dim, out_dim, bias=False) self.dim = hidden_dim self.do_checkpointing = do_checkpointing @@ -134,7 +134,7 @@ class TransformerDiffusion(nn.Module): linear(time_embed_dim, time_proj_dim), ) # nn.Embedding - self.resolution_embed = bnb.nn.StableEmbedding(resolution_steps, time_proj_dim) + self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim) self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64) self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim)) diff --git a/codes/models/audio/music/unet_diffusion_music_codes.py b/codes/models/audio/music/unet_diffusion_music_codes.py index 47293066..5a135502 100644 --- a/codes/models/audio/music/unet_diffusion_music_codes.py +++ b/codes/models/audio/music/unet_diffusion_music_codes.py @@ -8,7 +8,7 @@ import torch as th import torch.nn as nn import torch.nn.functional as F import torchvision # For debugging, not actually used. -import bitsandbytes as bnb +import torch_intermediary as ml from models.audio.music.gpt_music import GptMusicLower from models.audio.music.music_quantizer import MusicQuantizer @@ -491,7 +491,7 @@ class UNetMusicModel(nn.Module): ) if self.ar_prior: - self.ar_input = bnb.nn.Linear8bitLt(input_vec_dim, model_channels) + self.ar_input = ml.Linear(input_vec_dim, model_channels) self.ar_prior_intg = Encoder( dim=model_channels, depth=4, @@ -505,7 +505,7 @@ class UNetMusicModel(nn.Module): ff_mult=1, ) else: - self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, model_channels) + self.input_converter = ml.Linear(input_vec_dim, model_channels) self.code_converter = Encoder( dim=model_channels, depth=4, @@ -523,7 +523,7 @@ class UNetMusicModel(nn.Module): if self.num_classes is not None: # nn.Embedding - self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) + self.label_emb = ml.Embedding(num_classes, time_embed_dim) self.use_raw_y_as_embedding = use_raw_y_as_embedding assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive. diff --git a/codes/models/audio/tts/ctc_code_generator.py b/codes/models/audio/tts/ctc_code_generator.py index 1b3f63bd..801c3fa2 100644 --- a/codes/models/audio/tts/ctc_code_generator.py +++ b/codes/models/audio/tts/ctc_code_generator.py @@ -3,7 +3,7 @@ from random import random import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer from models.lucidrains.x_transformers import Encoder @@ -38,11 +38,11 @@ class CtcCodeGenerator(nn.Module): pred_codes = (max_pad+1)*(max_repeat+1) # nn.Embedding - self.position_embedding = bnb.nn.StableEmbedding(max_length, model_dim) + self.position_embedding = ml.Embedding(max_length, model_dim) # nn.Embedding - self.codes_embedding = bnb.nn.StableEmbedding(ctc_codes, model_dim) + self.codes_embedding = ml.Embedding(ctc_codes, model_dim) # nn.Embedding - self.recursive_embedding = bnb.nn.StableEmbedding(pred_codes, model_dim) + self.recursive_embedding = ml.Embedding(pred_codes, model_dim) self.mask_embedding = nn.Parameter(torch.randn(model_dim)) self.encoder = Encoder( dim=model_dim, @@ -54,8 +54,8 @@ class CtcCodeGenerator(nn.Module): ff_glu=True, rotary_pos_emb=True, ) - self.pred_head = bnb.nn.Linear8bitLt(model_dim, pred_codes) - self.confidence_head = bnb.nn.Linear8bitLt(model_dim, 1) + self.pred_head = ml.Linear(model_dim, pred_codes) + self.confidence_head = ml.Linear(model_dim, 1) def inference(self, codes, pads, repeats): position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device)) diff --git a/codes/models/audio/tts/diffusion_encoder.py b/codes/models/audio/tts/diffusion_encoder.py index 29eddccc..15c3a2f2 100644 --- a/codes/models/audio/tts/diffusion_encoder.py +++ b/codes/models/audio/tts/diffusion_encoder.py @@ -5,7 +5,7 @@ from functools import partial import torch import torch.nn as nn -import bitsandbytes as bnb +import torch_intermediary as ml from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \ DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \ @@ -18,7 +18,7 @@ class TimeIntegrationBlock(nn.Module): super().__init__() self.emb_layers = nn.Sequential( nn.SiLU(), - bnb.nn.Linear8bitLt( + ml.Linear( time_emb_dim, 2 * dim ), diff --git a/codes/models/audio/tts/mini_encoder.py b/codes/models/audio/tts/mini_encoder.py index e8fe5498..4a61199b 100644 --- a/codes/models/audio/tts/mini_encoder.py +++ b/codes/models/audio/tts/mini_encoder.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -import bitsandbytes as bnb +import torch_intermediary as ml from models.diffusion.nn import normalization, conv_nd, zero_module @@ -139,7 +139,7 @@ class AudioMiniEncoderWithClassifierHead(nn.Module): def __init__(self, classes, distribute_zero_label=True, **kwargs): super().__init__() self.enc = AudioMiniEncoder(**kwargs) - self.head = bnb.nn.Linear8bitLt(self.enc.dim, classes) + self.head = ml.Linear(self.enc.dim, classes) self.num_classes = classes self.distribute_zero_label = distribute_zero_label @@ -184,7 +184,7 @@ class QueryProvidedAttentionBlock(nn.Module): ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels self.norm = normalization(channels) - self.q = bnb.nn.Linear8bitLt(channels, channels) + self.q = ml.Linear(channels, channels) self.qnorm = nn.LayerNorm(channels) self.kv = conv_nd(1, channels, channels*2, 1) if use_new_attention_order: diff --git a/codes/models/audio/tts/random_latent_converter.py b/codes/models/audio/tts/random_latent_converter.py index c5331786..9f14ebb3 100644 --- a/codes/models/audio/tts/random_latent_converter.py +++ b/codes/models/audio/tts/random_latent_converter.py @@ -3,7 +3,7 @@ import math import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from trainer.networks import register_model from utils.util import opt_get @@ -45,7 +45,7 @@ class RandomLatentConverter(nn.Module): def __init__(self, channels): super().__init__() self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)], - bnb.nn.Linear8bitLt(channels, channels)) + ml.Linear(channels, channels)) self.channels = channels def forward(self, ref): diff --git a/codes/models/audio/tts/tacotron2/layers.py b/codes/models/audio/tts/tacotron2/layers.py index 925ddffd..11022c02 100644 --- a/codes/models/audio/tts/tacotron2/layers.py +++ b/codes/models/audio/tts/tacotron2/layers.py @@ -3,13 +3,13 @@ from librosa.filters import mel as librosa_mel_fn from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression from models.audio.tts.tacotron2.stft import STFT -import bitsandbytes as bnb +import torch_intermediary as ml class LinearNorm(torch.nn.Module): def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): super(LinearNorm, self).__init__() - self.linear_layer = torch.bnb.nn.Linear8bitLt(in_dim, out_dim, bias=bias) + self.linear_layer = torch.ml.Linear(in_dim, out_dim, bias=bias) torch.nn.init.xavier_uniform_( self.linear_layer.weight, diff --git a/codes/models/audio/tts/tacotron2/tacotron2.py b/codes/models/audio/tts/tacotron2/tacotron2.py index 13e9bef1..f8f64cb7 100644 --- a/codes/models/audio/tts/tacotron2/tacotron2.py +++ b/codes/models/audio/tts/tacotron2/tacotron2.py @@ -8,7 +8,7 @@ from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm from models.audio.tts.tacotron2.hparams import create_hparams from trainer.networks import register_model from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths -import bitsandbytes as bnb +import torch_intermediary as ml class LocationLayer(nn.Module): @@ -465,7 +465,7 @@ class Tacotron2(nn.Module): self.n_mel_channels = hparams.n_mel_channels self.n_frames_per_step = hparams.n_frames_per_step # nn.Embedding - self.embedding = bnb.nn.StableEmbedding( + self.embedding = ml.Embedding( hparams.n_symbols, hparams.symbols_embedding_dim) std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) val = sqrt(3.0) * std # uniform bounds for std diff --git a/codes/models/audio/tts/tacotron2/wave_tacotron.py b/codes/models/audio/tts/tacotron2/wave_tacotron.py index 510f94c6..8e73e93d 100644 --- a/codes/models/audio/tts/tacotron2/wave_tacotron.py +++ b/codes/models/audio/tts/tacotron2/wave_tacotron.py @@ -13,7 +13,7 @@ from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder from trainer.networks import register_model from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths from utils.util import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml @@ -187,7 +187,7 @@ class WaveTacotron2(nn.Module): self.n_mel_channels = hparams.n_mel_channels self.n_frames_per_step = hparams.n_frames_per_step # nn.Embedding - self.embedding = bnb.nn.StableEmbedding( + self.embedding = ml.Embedding( hparams.n_symbols, hparams.symbols_embedding_dim) std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) val = sqrt(3.0) * std # uniform bounds for std diff --git a/codes/models/audio/tts/transformer_builders.py b/codes/models/audio/tts/transformer_builders.py index 74e7ceb2..ce88d1a7 100644 --- a/codes/models/audio/tts/transformer_builders.py +++ b/codes/models/audio/tts/transformer_builders.py @@ -25,7 +25,7 @@ import random from time import time import torch import torch.nn as nn -import bitsandbytes as bnb +import torch_intermediary as ml from tqdm import tqdm @@ -37,7 +37,7 @@ class LearnedPositionEmbeddings(nn.Module): def __init__(self, seq_len, model_dim, init=.02, relative=False): super().__init__() # nn.Embedding - self.emb = bnb.nn.StableEmbedding(seq_len, model_dim) + self.emb = ml.Embedding(seq_len, model_dim) # Initializing this way is standard for GPT-2 self.emb.weight.data.normal_(mean=0.0, std=init) self.relative = relative diff --git a/codes/models/audio/tts/transformer_diffusion_tts.py b/codes/models/audio/tts/transformer_diffusion_tts.py index 00b6485e..bb78a008 100644 --- a/codes/models/audio/tts/transformer_diffusion_tts.py +++ b/codes/models/audio/tts/transformer_diffusion_tts.py @@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlo from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding from trainer.networks import register_model from utils.util import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml def is_latent(t): @@ -21,7 +21,7 @@ class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() # nn.Embedding - self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) + self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -102,9 +102,9 @@ class TransformerDiffusionTTS(nn.Module): ff_glu=True, rotary_pos_emb=True, ) - self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, model_channels) + self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels) # nn.Embedding - self.type_embedding = bnb.nn.StableEmbedding(types, model_channels) + self.type_embedding = ml.Embedding(types, model_channels) # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally @@ -112,7 +112,7 @@ class TransformerDiffusionTTS(nn.Module): # transformer network. if in_groups is None: # nn.Embedding - self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels) + self.embeddings = ml.Embedding(token_count, model_channels) else: self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) self.latent_conditioner = nn.Sequential( @@ -144,7 +144,7 @@ class TransformerDiffusionTTS(nn.Module): self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = bnb.nn.Linear8bitLt(model_channels*2, model_channels) + self.intg = ml.Linear(model_channels*2, model_channels) self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( diff --git a/codes/models/audio/tts/transformer_diffusion_tts2.py b/codes/models/audio/tts/transformer_diffusion_tts2.py index cfe6b625..e4cb3a6e 100644 --- a/codes/models/audio/tts/transformer_diffusion_tts2.py +++ b/codes/models/audio/tts/transformer_diffusion_tts2.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock @@ -21,7 +21,7 @@ class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() # nn.Embedding - self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) + self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -42,7 +42,7 @@ class DietAttentionBlock(TimestepBlock): def __init__(self, in_dim, dim, heads, dropout): super().__init__() self.rms_scale_norm = RMSScaleShiftNorm(in_dim) - self.proj = bnb.nn.Linear8bitLt(in_dim, dim) + self.proj = ml.Linear(in_dim, dim) self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) @@ -107,9 +107,9 @@ class TransformerDiffusionTTS(nn.Module): ff_glu=True, rotary_pos_emb=True, ) - self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, prenet_channels) + self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels) # nn.Embedding - self.type_embedding = bnb.nn.StableEmbedding(types, prenet_channels) + self.type_embedding = ml.Embedding(types, prenet_channels) # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally @@ -117,7 +117,7 @@ class TransformerDiffusionTTS(nn.Module): # transformer network. if in_groups is None: # nn.Embedding - self.embeddings = bnb.nn.StableEmbedding(token_count, prenet_channels) + self.embeddings = ml.Embedding(token_count, prenet_channels) else: self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels) self.latent_conditioner = nn.Sequential( @@ -148,8 +148,8 @@ class TransformerDiffusionTTS(nn.Module): self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.cond_intg = bnb.nn.Linear8bitLt(prenet_channels*4, model_channels) - self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels) + self.cond_intg = ml.Linear(prenet_channels*4, model_channels) + self.intg = ml.Linear(prenet_channels*2, model_channels) self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) diff --git a/codes/models/audio/tts/unet_diffusion_tts7.py b/codes/models/audio/tts/unet_diffusion_tts7.py index ccfb4735..02d25cb3 100644 --- a/codes/models/audio/tts/unet_diffusion_tts7.py +++ b/codes/models/audio/tts/unet_diffusion_tts7.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast -import bitsandbytes as bnb +import torch_intermediary as ml from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ @@ -249,7 +249,7 @@ class DiffusionTts(nn.Module): embedding_dim = model_channels * 8 # nn.Embedding - self.code_embedding = bnb.nn.StableEmbedding(num_tokens+1, embedding_dim) + self.code_embedding = ml.Embedding(num_tokens+1, embedding_dim) self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1, attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1) @@ -257,7 +257,7 @@ class DiffusionTts(nn.Module): self.enable_unaligned_inputs = enabled_unaligned_inputs if enabled_unaligned_inputs: # nn.Embedding - self.unaligned_embedder = bnb.nn.StableEmbedding(num_unaligned_tokens, embedding_dim) + self.unaligned_embedder = ml.Embedding(num_unaligned_tokens, embedding_dim) self.unaligned_encoder = CheckpointedXTransformerEncoder( max_seq_len=-1, use_pos_emb=False, diff --git a/codes/models/audio/tts/unet_diffusion_tts9.py b/codes/models/audio/tts/unet_diffusion_tts9.py index 341ff474..a00b2758 100644 --- a/codes/models/audio/tts/unet_diffusion_tts9.py +++ b/codes/models/audio/tts/unet_diffusion_tts9.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import autocast from x_transformers import Encoder -import bitsandbytes as bnb +import torch_intermediary as ml from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ @@ -208,7 +208,7 @@ class DiffusionTts(nn.Module): # transformer network. self.code_converter = nn.Sequential( # nn.Embedding - bnb.nn.StableEmbedding(in_tokens, conditioning_dim), + ml.Embedding(in_tokens, conditioning_dim), CheckpointedXTransformerEncoder( needs_permute=False, max_seq_len=-1, diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py index 9ee14de4..34017011 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast -import bitsandbytes as bnb +import torch_intermediary as ml from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy @@ -196,7 +196,7 @@ class DiffusionTtsFlat(nn.Module): # transformer network. # nn.Embedding - self.code_embedding = bnb.nn.StableEmbedding(in_tokens, model_channels) + self.code_embedding = ml.Embedding(in_tokens, model_channels) self.code_converter = nn.Sequential( AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py index b3382f3a..80e7a394 100644 --- a/codes/models/audio/tts/unified_voice2.py +++ b/codes/models/audio/tts/unified_voice2.py @@ -13,7 +13,7 @@ from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_e from trainer.networks import register_model from utils.util import opt_get -import bitsandbytes as bnb +import torch_intermediary as ml class ResBlock(nn.Module): """ @@ -282,10 +282,10 @@ class UnifiedVoice(nn.Module): self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.average_conditioning_embeddings = average_conditioning_embeddings # nn.Embedding - self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens, model_dim) + self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim) if use_mel_codes_as_input: # nn.Embedding - self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) + self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) else: self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ @@ -298,8 +298,8 @@ class UnifiedVoice(nn.Module): self.text_solo_embedding = 0 self.final_norm = nn.LayerNorm(model_dim) - self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens) - self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) + self.text_head = ml.Linear(model_dim, self.number_text_tokens) + self.mel_head = ml.Linear(model_dim, self.number_mel_codes) # Initialize the embeddings per the GPT-2 scheme embeddings = [self.text_embedding] diff --git a/codes/models/audio/tts/unified_voice3.py b/codes/models/audio/tts/unified_voice3.py index 1dfce7f6..49e2258c 100644 --- a/codes/models/audio/tts/unified_voice3.py +++ b/codes/models/audio/tts/unified_voice3.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from transformers import GPT2Config, GPT2PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions @@ -274,16 +274,16 @@ class UnifiedVoice(nn.Module): self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) # nn.Embedding - self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim) + self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim) # nn.Embedding - self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) + self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) self.final_norm = nn.LayerNorm(model_dim) - self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1) - self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) - self.aligned_head = bnb.nn.Linear8bitLt(model_dim, number_aligned_text_codes) + self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1) + self.mel_head = ml.Linear(model_dim, self.number_mel_codes) + self.aligned_head = ml.Linear(model_dim, number_aligned_text_codes) # Initialize the embeddings per the GPT-2 scheme embeddings = [self.text_embedding, self.mel_embedding] diff --git a/codes/models/audio/tts/unified_voice4.py b/codes/models/audio/tts/unified_voice4.py index d76d3cc3..9d8a8568 100644 --- a/codes/models/audio/tts/unified_voice4.py +++ b/codes/models/audio/tts/unified_voice4.py @@ -11,7 +11,7 @@ from models.audio.tts.transformer_builders import build_hf_gpt_transformer from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb from trainer.networks import register_model from utils.util import opt_get -import bitsandbytes as bnb +import torch_intermediary as ml class ResBlock(nn.Module): @@ -257,16 +257,16 @@ class UnifiedVoice(nn.Module): self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) # nn.Embedding - self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim) + self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim) # nn.Embedding - self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) + self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) self.final_norm = nn.LayerNorm(model_dim) - self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1) - self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) - self.alignment_head = bnb.nn.Linear8bitLt(model_dim, 256) + self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1) + self.mel_head = ml.Linear(model_dim, self.number_mel_codes) + self.alignment_head = ml.Linear(model_dim, 256) if only_alignment_head: for p in self.parameters(): diff --git a/codes/models/audio/tts/voice_voice_clip.py b/codes/models/audio/tts/voice_voice_clip.py index 11401668..842f8b17 100644 --- a/codes/models/audio/tts/voice_voice_clip.py +++ b/codes/models/audio/tts/voice_voice_clip.py @@ -8,7 +8,7 @@ from models.audio.tts.mini_encoder import AudioMiniEncoder from trainer.injectors.spec_augment import spec_augment from trainer.networks import register_model from utils.util import opt_get -import bitsandbytes as bnb +import torch_intermediary as ml def exists(val): @@ -37,7 +37,7 @@ class VoiceCLIP(nn.Module): self.encoder = AudioMiniEncoder(80, encoder_output) if pretrained_encoder_dict_path is not None: self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path)) - self.to_latent = bnb.nn.Linear8bitLt(encoder_output, dim_latent, bias=False) + self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False) self.temperature = nn.Parameter(torch.tensor(1.)) self.mel_compression_ratio = mel_compression_ratio diff --git a/codes/models/audio/tts/w2v_matcher.py b/codes/models/audio/tts/w2v_matcher.py index 61ac850a..636a1ca0 100644 --- a/codes/models/audio/tts/w2v_matcher.py +++ b/codes/models/audio/tts/w2v_matcher.py @@ -7,7 +7,7 @@ from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper from models.audio.tts.mini_encoder import AudioMiniEncoder from trainer.networks import register_model -import bitsandbytes as bnb +import torch_intermediary as ml class CheckpointedLayer(nn.Module): @@ -58,7 +58,7 @@ class Wav2VecMatcher(nn.Module): self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1, attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) # nn.Embedding - self.text_embedding = bnb.nn.StableEmbedding(num_text_tokens, model_dim) + self.text_embedding = ml.Embedding(num_text_tokens, model_dim) self.encoder = CheckpointedXTransformer( max_seq_len=-1, use_pos_emb=False, @@ -75,8 +75,8 @@ class Wav2VecMatcher(nn.Module): ) self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim)) self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim)) - self.w2v_query_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim) - self.w2v_value_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim) + self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim) + self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim) self.decoder = CheckpointedXTransformer( max_seq_len=-1, # Should be unused use_pos_emb=False, diff --git a/codes/models/classifiers/cifar_resnet.py b/codes/models/classifiers/cifar_resnet.py index d4efbdf4..86a7af0e 100644 --- a/codes/models/classifiers/cifar_resnet.py +++ b/codes/models/classifiers/cifar_resnet.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn -import bitsandbytes as bnb +import torch_intermediary as ml from trainer.networks import register_model @@ -99,7 +99,7 @@ class ResNet(nn.Module): self.conv4_x = self._make_layer(block, 128, num_block[2], 2) self.conv5_x = self._make_layer(block, 256, num_block[3], 2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes) + self.fc = ml.Linear(256 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): """make resnet layers(by layer i didnt mean this 'layer' was the diff --git a/codes/models/classifiers/resnet_with_checkpointing.py b/codes/models/classifiers/resnet_with_checkpointing.py index 526edcd9..f6f6c5e7 100644 --- a/codes/models/classifiers/resnet_with_checkpointing.py +++ b/codes/models/classifiers/resnet_with_checkpointing.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from torchvision.models.resnet import BasicBlock, Bottleneck import torchvision -import bitsandbytes as bnb +import torch_intermediary as ml __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', @@ -195,5 +195,5 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs): def register_resnet50(opt_net, opt): model = resnet50(pretrained=opt_net['pretrained']) if opt_net['custom_head_logits']: - model.fc = bnb.nn.Linear8bitLt(512 * 4, opt_net['custom_head_logits']) + model.fc = ml.Linear(512 * 4, opt_net['custom_head_logits']) return model diff --git a/codes/models/classifiers/twin_cifar_resnet.py b/codes/models/classifiers/twin_cifar_resnet.py index 2025f65a..bf4d1214 100644 --- a/codes/models/classifiers/twin_cifar_resnet.py +++ b/codes/models/classifiers/twin_cifar_resnet.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml from trainer.networks import register_model @@ -102,7 +102,7 @@ class ResNet(nn.Module): self.conv4_x = self._make_layer(block, 128, num_block[2], 2) self.conv5_x = self._make_layer(block, 256, num_block[3], 2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes) + self.fc = ml.Linear(256 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): """make resnet layers(by layer i didnt mean this 'layer' was the diff --git a/codes/models/classifiers/weighted_conv_resnet.py b/codes/models/classifiers/weighted_conv_resnet.py index dfbb6724..5b5beaf6 100644 --- a/codes/models/classifiers/weighted_conv_resnet.py +++ b/codes/models/classifiers/weighted_conv_resnet.py @@ -11,7 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', from models.vqvae.scaled_weight_conv import ScaledWeightConv from trainer.networks import register_model from utils.util import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', @@ -214,7 +214,7 @@ class ResNet(nn.Module): self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes) + self.fc = ml.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, ScaledWeightConv): diff --git a/codes/models/classifiers/wide_kernel_vgg.py b/codes/models/classifiers/wide_kernel_vgg.py index 0bb02f7b..ab26fb0f 100644 --- a/codes/models/classifiers/wide_kernel_vgg.py +++ b/codes/models/classifiers/wide_kernel_vgg.py @@ -3,7 +3,7 @@ import torch.nn as nn from trainer.networks import register_model from utils.util import opt_get -import bitsandbytes as bnb +import torch_intermediary as ml class WideKernelVgg(nn.Module): def __init__(self, nf=64, num_classes=2): @@ -49,9 +49,9 @@ class WideKernelVgg(nn.Module): nn.ReLU(), nn.MaxPool2d(kernel_size=2), nn.Flatten(), - bnb.nn.Linear8bitLt(nf * 8 * 4 * 2, 100), + ml.Linear(nf * 8 * 4 * 2, 100), nn.ReLU(), - bnb.nn.Linear8bitLt(100, num_classes) + ml.Linear(100, num_classes) ) # These normalization constants should be derived experimentally. diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py index c22da91f..9bfc7655 100644 --- a/codes/models/clip/clvp.py +++ b/codes/models/clip/clvp.py @@ -10,7 +10,7 @@ from models.arch_util import AttentionBlock from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from trainer.networks import register_model from utils.util import opt_get, checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml def exists(val): @@ -60,7 +60,7 @@ class ConvFormatEmbedding(nn.Module): def __init__(self, *args, **kwargs): super().__init__() # nn.Embedding - self.emb = bnb.nn.StableEmbedding(*args, **kwargs) + self.emb = ml.Embedding(*args, **kwargs) def forward(self, x): y = self.emb(x) @@ -101,9 +101,9 @@ class CLVP(nn.Module): self.mask_conditioning_percentage = mask_conditioning_percentage # nn.Embedding - self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, model_dim) + self.text_emb = ml.Embedding(num_text_tokens, model_dim) self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True) - self.to_text_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) + self.to_text_latent = ml.Linear(latent_dim, latent_dim, bias=False) self.distributed_collect = distributed_collect if mel_codes is None: @@ -111,7 +111,7 @@ class CLVP(nn.Module): else: self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) - self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) + self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False) def get_grad_norm_parameter_groups(self): return { diff --git a/codes/models/clip/contrastive_audio.py b/codes/models/clip/contrastive_audio.py index ce53ddce..85edfec6 100644 --- a/codes/models/clip/contrastive_audio.py +++ b/codes/models/clip/contrastive_audio.py @@ -9,7 +9,7 @@ from models.arch_util import AttentionBlock from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from trainer.networks import register_model from utils.util import opt_get, checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml def exists(val): @@ -180,7 +180,7 @@ class ConvFormatEmbedding(nn.Module): def __init__(self, *args, **kwargs): super().__init__() # nn.Embedding - self.emb = bnb.nn.StableEmbedding(*args, **kwargs) + self.emb = ml.Embedding(*args, **kwargs) def forward(self, x): y = self.emb(x) @@ -205,8 +205,8 @@ class ContrastiveAudio(nn.Module): self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2), nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent) - self.to_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) - self.to_latent2 = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) + self.to_latent = ml.Linear(latent_dim, latent_dim, bias=False) + self.to_latent2 = ml.Linear(latent_dim, latent_dim, bias=False) self.to_latent2.weight.data = self.to_latent.weight.data self.to_latent2.weight.DO_NOT_TRAIN = True diff --git a/codes/models/clip/cvvp.py b/codes/models/clip/cvvp.py index 24567265..e8a12b2a 100644 --- a/codes/models/clip/cvvp.py +++ b/codes/models/clip/cvvp.py @@ -10,7 +10,7 @@ from models.arch_util import AttentionBlock from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from trainer.networks import register_model from utils.util import opt_get, checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml def exists(val): @@ -60,7 +60,7 @@ class ConvFormatEmbedding(nn.Module): def __init__(self, *args, **kwargs): super().__init__() # nn.Embedding - self.emb = bnb.nn.StableEmbedding(*args, **kwargs) + self.emb = ml.Embedding(*args, **kwargs) def forward(self, x): y = self.emb(x) @@ -88,14 +88,14 @@ class CVVP(nn.Module): self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2), nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage) - self.to_conditioning_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) + self.to_conditioning_latent = ml.Linear(latent_dim, latent_dim, bias=False) if mel_codes is None: self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2) else: self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) - self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) + self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False) def get_grad_norm_parameter_groups(self): return { diff --git a/codes/models/clip/mel_text_clip.py b/codes/models/clip/mel_text_clip.py index 6d55696a..c053547d 100644 --- a/codes/models/clip/mel_text_clip.py +++ b/codes/models/clip/mel_text_clip.py @@ -7,7 +7,7 @@ from torch import einsum from models.lucidrains.dalle.transformer import Transformer from trainer.networks import register_model from utils.util import opt_get -import bitsandbytes as bnb +import torch_intermediary as ml def exists(val): @@ -47,19 +47,19 @@ class MelTextCLIP(nn.Module): ): super().__init__() # nn.Embedding - self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text) + self.text_emb = ml.Embedding(num_text_tokens, dim_text) # nn.Embedding - self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text) + self.text_pos_emb = ml.Embedding(text_seq_len, dim_text) self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads, rotary_emb=False) - self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False) + self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False) self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1) # nn.Embedding - self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) + self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech) self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads, rotary_emb=False) - self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) + self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False) self.temperature = nn.Parameter(torch.tensor(1.)) self.text_mask_percentage = text_mask_percentage diff --git a/codes/models/clip/text_cond_clip.py b/codes/models/clip/text_cond_clip.py index 12ebdbc3..f221142c 100644 --- a/codes/models/clip/text_cond_clip.py +++ b/codes/models/clip/text_cond_clip.py @@ -7,7 +7,7 @@ from models.audio.tts.unified_voice2 import ConditioningEncoder from models.lucidrains.dalle.transformer import Transformer from trainer.networks import register_model from utils.util import opt_get -import bitsandbytes as bnb +import torch_intermediary as ml def exists(val): @@ -46,7 +46,7 @@ class VoiceCondCLIP(nn.Module): self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads, rotary_emb=False) - self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) + self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False) self.temperature = nn.Parameter(torch.tensor(1.)) self.voice_mask_percentage = voice_mask_percentage diff --git a/codes/models/clip/text_voice_clip.py b/codes/models/clip/text_voice_clip.py index 0fb1670b..26bac0c9 100644 --- a/codes/models/clip/text_voice_clip.py +++ b/codes/models/clip/text_voice_clip.py @@ -11,7 +11,7 @@ from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder from models.lucidrains.dalle.transformer import Transformer from trainer.networks import register_model from utils.util import opt_get -import bitsandbytes as bnb +import torch_intermediary as ml def exists(val): @@ -55,12 +55,12 @@ class VoiceCLIP(nn.Module): ): super().__init__() # nn.Embedding - self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text) - self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False) + self.text_emb = ml.Embedding(num_text_tokens, dim_text) + self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False) # nn.Embedding - self.speech_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) - self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) + self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech) + self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False) if use_xformers: self.text_transformer = CheckpointedXTransformerEncoder( @@ -109,9 +109,9 @@ class VoiceCLIP(nn.Module): self.distributed_collect = distributed_collect if not use_xformers: # nn.Embedding - self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text) + self.text_pos_emb = ml.Embedding(text_seq_len, dim_text) # nn.Embedding - self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) + self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech) def embed_text(self, text): text_mask = torch.ones_like(text.float()).bool() diff --git a/codes/models/diffusion/nn.py b/codes/models/diffusion/nn.py index a40343b7..201fadb9 100644 --- a/codes/models/diffusion/nn.py +++ b/codes/models/diffusion/nn.py @@ -6,7 +6,7 @@ import math import torch as th import torch.nn as nn -import bitsandbytes as bnb +import torch_intermediary as ml # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. @@ -37,7 +37,7 @@ def linear(*args, **kwargs): """ Create a linear module. """ - return bnb.nn.Linear8bitLt(*args, **kwargs) + return ml.Linear(*args, **kwargs) def avg_pool_nd(dims, *args, **kwargs): diff --git a/codes/models/diffusion/rrdb_diffusion.py b/codes/models/diffusion/rrdb_diffusion.py index 0ca91b5b..933fe467 100644 --- a/codes/models/diffusion/rrdb_diffusion.py +++ b/codes/models/diffusion/rrdb_diffusion.py @@ -6,7 +6,7 @@ from models.arch_util import ConvGnLelu, default_init_weights, make_layer from models.diffusion.nn import timestep_embedding from trainer.networks import register_model from utils.util import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml # Conditionally uses torch's checkpoint functionality if it is enabled in the opt file. @@ -29,7 +29,7 @@ class ResidualDenseBlock(nn.Module): self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True) self.emb_layers = nn.Sequential( nn.SiLU(), - bnb.nn.Linear8bitLt( + ml.Linear( mid_channels*4, mid_channels, ), @@ -144,9 +144,9 @@ class RRDBNet(nn.Module): # Guided diffusion uses a time embedding. time_embed_dim = mid_channels * 4 self.time_embed = nn.Sequential( - bnb.nn.Linear8bitLt(mid_channels, time_embed_dim), + ml.Linear(mid_channels, time_embed_dim), nn.SiLU(), - bnb.nn.Linear8bitLt(time_embed_dim, time_embed_dim), + ml.Linear(time_embed_dim, time_embed_dim), ) self.body = make_layer( diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py index ea5ec504..6ebbae02 100644 --- a/codes/models/diffusion/unet_diffusion.py +++ b/codes/models/diffusion/unet_diffusion.py @@ -20,7 +20,7 @@ from models.diffusion.nn import ( ) from trainer.networks import register_model from utils.util import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml class AttentionPool2d(nn.Module): @@ -517,7 +517,7 @@ class UNetModel(nn.Module): if self.num_classes is not None: # nn.Embedding - self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) + self.label_emb = ml.Embedding(num_classes, time_embed_dim) self.use_raw_y_as_embedding = use_raw_y_as_embedding assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive. @@ -869,16 +869,16 @@ class EncoderUNetModel(nn.Module): ) elif pool == "spatial": self.out = nn.Sequential( - bnb.nn.Linear8bitLt(self._feature_size, 2048), + ml.Linear(self._feature_size, 2048), nn.ReLU(), - bnb.nn.Linear8bitLt(2048, self.out_channels), + ml.Linear(2048, self.out_channels), ) elif pool == "spatial_v2": self.out = nn.Sequential( - bnb.nn.Linear8bitLt(self._feature_size, 2048), + ml.Linear(self._feature_size, 2048), normalization(2048), nn.SiLU(), - bnb.nn.Linear8bitLt(2048, self.out_channels), + ml.Linear(2048, self.out_channels), ) else: raise NotImplementedError(f"Unexpected {pool} pooling") diff --git a/codes/models/diffusion/unet_latent_guide.py b/codes/models/diffusion/unet_latent_guide.py index 52272141..e298a900 100644 --- a/codes/models/diffusion/unet_latent_guide.py +++ b/codes/models/diffusion/unet_latent_guide.py @@ -26,7 +26,7 @@ from models.diffusion.nn import ( ) from trainer.networks import register_model from utils.util import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml class AttentionPool2d(nn.Module): @@ -478,7 +478,7 @@ class UNetModel(nn.Module): if self.num_classes is not None: # nn.Embedding - self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) + self.label_emb = ml.Embedding(num_classes, time_embed_dim) self.input_blocks = nn.ModuleList( [ @@ -738,7 +738,7 @@ class ResNetEncoder(nn.Module): dilate=replace_stride_with_dilation[2]) f=512 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = bnb.nn.Linear8bitLt(f * block.expansion, output_dim) + self.fc = ml.Linear(f * block.expansion, output_dim) for m in self.modules(): if isinstance(m, nn.Conv2d): diff --git a/codes/models/image_generation/discriminator_vgg_arch.py b/codes/models/image_generation/discriminator_vgg_arch.py index 05ea3e8b..af44eca2 100644 --- a/codes/models/image_generation/discriminator_vgg_arch.py +++ b/codes/models/image_generation/discriminator_vgg_arch.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from trainer.networks import register_model from utils.util import checkpoint, opt_get -import bitsandbytes as bnb +import torch_intermediary as ml class Discriminator_VGG_128(nn.Module): @@ -47,8 +47,8 @@ class Discriminator_VGG_128(nn.Module): input_img_factor = input_img_factor // 2 final_nf = nf * 16 - self.linear1 = bnb.nn.Linear8bitLt(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) - self.linear2 = bnb.nn.Linear8bitLt(100, 1) + self.linear1 = ml.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) + self.linear2 = ml.Linear(100, 1) # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) @@ -130,8 +130,8 @@ class Discriminator_VGG_128_GN(nn.Module): # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) - self.linear2 = bnb.nn.Linear8bitLt(100, 1) + self.linear1 = ml.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) + self.linear2 = ml.Linear(100, 1) def compute_body(self, x): fea = self.lrelu(self.conv0_0(x)) @@ -220,8 +220,8 @@ class DiscriminatorVGG448GN(nn.Module): self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) final_nf = nf * 8 - self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 7 * 7), 100) - self.linear2 = bnb.nn.Linear8bitLt(100, 1) + self.linear1 = ml.Linear(int(final_nf * 7 * 7), 100) + self.linear2 = ml.Linear(100, 1) # Assign all new heads to the new param group.2 for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]: diff --git a/codes/models/image_generation/srflow/module_util.py b/codes/models/image_generation/srflow/module_util.py index 032e3af3..f50198dd 100644 --- a/codes/models/image_generation/srflow/module_util.py +++ b/codes/models/image_generation/srflow/module_util.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.init as init import torch.nn.functional as F -import bitsandbytes as bnb +import torch_intermediary as ml def initialize_weights(net_l, scale=1): @@ -15,7 +15,7 @@ def initialize_weights(net_l, scale=1): m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() - elif isinstance(m, bnb.nn.Linear8bitLt): + elif isinstance(m, ml.Linear): init.kaiming_normal_(m.weight, a=0, mode='fan_in') m.weight.data *= scale if m.bias is not None: diff --git a/codes/models/image_generation/stylegan/stylegan2_lucidrains.py b/codes/models/image_generation/stylegan/stylegan2_lucidrains.py index bb72d90c..6016fb9e 100644 --- a/codes/models/image_generation/stylegan/stylegan2_lucidrains.py +++ b/codes/models/image_generation/stylegan/stylegan2_lucidrains.py @@ -28,7 +28,7 @@ except: APEX_AVAILABLE = False assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' -import bitsandbytes as bnb +import torch_intermediary as ml num_cores = multiprocessing.cpu_count() @@ -352,7 +352,7 @@ class RGBBlock(nn.Module): def __init__(self, latent_dim, input_channel, upsample, rgba=False): super().__init__() self.input_channel = input_channel - self.to_style = bnb.nn.Linear8bitLt(latent_dim, input_channel) + self.to_style = ml.Linear(latent_dim, input_channel) out_filters = 3 if not rgba else 4 self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False) @@ -490,16 +490,16 @@ class GeneratorBlockWithStructure(nn.Module): # Uses stylegan1 style blocks for injecting structural latent. self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1) - self.to_noise0 = bnb.nn.Linear8bitLt(1, filters) + self.to_noise0 = ml.Linear(1, filters) self.noise0 = equal_lr(NoiseInjection(filters)) self.adain0 = AdaptiveInstanceNorm(filters, latent_dim) - self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, filters) - self.to_noise1 = bnb.nn.Linear8bitLt(1, filters) + self.to_style1 = ml.Linear(latent_dim, filters) + self.to_noise1 = ml.Linear(1, filters) self.conv1 = Conv2DMod(filters, filters, 3) - self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters) - self.to_noise2 = bnb.nn.Linear8bitLt(1, filters) + self.to_style2 = ml.Linear(latent_dim, filters) + self.to_noise2 = ml.Linear(1, filters) self.conv2 = Conv2DMod(filters, filters, 3) self.activation = leaky_relu() @@ -541,12 +541,12 @@ class GeneratorBlock(nn.Module): self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1) input_channels = input_channels * 2 - self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, input_channels) - self.to_noise1 = bnb.nn.Linear8bitLt(1, filters) + self.to_style1 = ml.Linear(latent_dim, input_channels) + self.to_noise1 = ml.Linear(1, filters) self.conv1 = Conv2DMod(input_channels, filters, 3) - self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters) - self.to_noise2 = bnb.nn.Linear8bitLt(1, filters) + self.to_style2 = ml.Linear(latent_dim, filters) + self.to_noise2 = ml.Linear(1, filters) self.conv2 = Conv2DMod(filters, filters, 3) self.activation = leaky_relu() @@ -725,7 +725,7 @@ class StyleGan2GeneratorWithLatent(nn.Module): def _init_weights(self): for m in self.modules(): - if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt} and hasattr(m, 'weight'): + if type(m) in {nn.Conv2d, ml.Linear} and hasattr(m, 'weight'): nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') for block in self.gen.blocks: @@ -805,7 +805,7 @@ class StyleGan2Discriminator(nn.Module): self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1) self.flatten = Flatten() - self.to_logit = bnb.nn.Linear8bitLt(latent_dim, 1) + self.to_logit = ml.Linear(latent_dim, 1) self._init_weights() @@ -837,7 +837,7 @@ class StyleGan2Discriminator(nn.Module): def _init_weights(self): for m in self.modules(): - if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt}: + if type(m) in {nn.Conv2d, ml.Linear}: nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') diff --git a/codes/models/image_latents/byol/byol_model_wrapper.py b/codes/models/image_latents/byol/byol_model_wrapper.py index 105d7ec7..6c8bb3e3 100644 --- a/codes/models/image_latents/byol/byol_model_wrapper.py +++ b/codes/models/image_latents/byol/byol_model_wrapper.py @@ -12,7 +12,7 @@ from torch import nn from data.images.byol_attachment import RandomApply from trainer.networks import register_model, create_model from utils.util import checkpoint, opt_get -import bitsandbytes as bnb +import torch_intermediary as ml def default(val, def_val): @@ -79,10 +79,10 @@ class MLP(nn.Module): def __init__(self, dim, projection_size, hidden_size=4096): super().__init__() self.net = nn.Sequential( - bnb.nn.Linear8bitLt(dim, hidden_size), + ml.Linear(dim, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(inplace=True), - bnb.nn.Linear8bitLt(hidden_size, projection_size) + ml.Linear(hidden_size, projection_size) ) def forward(self, x): @@ -104,10 +104,10 @@ class StructuralMLP(nn.Module): nn.BatchNorm2d(c), nn.ReLU(inplace=True), nn.Flatten(), - bnb.nn.Linear8bitLt(flattened_dim, hidden_size), + ml.Linear(flattened_dim, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(inplace=True), - bnb.nn.Linear8bitLt(hidden_size, projection_size) + ml.Linear(hidden_size, projection_size) ) def forward(self, x): diff --git a/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py b/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py index 2513ba94..2c2a8fcd 100644 --- a/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py +++ b/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import numpy as np -import bitsandbytes as bnb +import torch_intermediary as ml __all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] @@ -109,8 +109,8 @@ class FixupResNet(nn.Module): self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2) self.bias2 = nn.Parameter(torch.zeros(1)) reduced_img_sz = int(input_img_size / 32) - self.fc1 = bnb.nn.Linear8bitLt(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) - self.fc2 = bnb.nn.Linear8bitLt(100, num_classes) + self.fc1 = ml.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) + self.fc2 = ml.Linear(100, num_classes) for m in self.modules(): if isinstance(m, FixupBasicBlock): @@ -125,7 +125,7 @@ class FixupResNet(nn.Module): if m.downsample is not None: nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) ''' - elif isinstance(m, bnb.nn.Linear8bitLt): + elif isinstance(m, ml.Linear): nn.init.constant_(m.weight, 0) nn.init.constant_(m.bias, 0)''' diff --git a/codes/models/image_latents/vit_latent.py b/codes/models/image_latents/vit_latent.py index 3f45a5c8..b38671be 100644 --- a/codes/models/image_latents/vit_latent.py +++ b/codes/models/image_latents/vit_latent.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from models.arch_util import ResBlock from models.lucidrains.x_transformers import Encoder from trainer.networks import register_model -import bitsandbytes as bnb +import torch_intermediary as ml class VitLatent(nn.Module): @@ -32,10 +32,10 @@ class VitLatent(nn.Module): do_checkpointing=True ) - self.mlp = nn.Sequential(bnb.nn.Linear8bitLt(hidden_dim, hidden_dim*2), + self.mlp = nn.Sequential(ml.Linear(hidden_dim, hidden_dim*2), nn.BatchNorm1d(hidden_dim*2), nn.ReLU(inplace=True), - bnb.nn.Linear8bitLt(hidden_dim*2, hidden_dim)) + ml.Linear(hidden_dim*2, hidden_dim)) def provide_ema(self, ema): self.ema = ema diff --git a/codes/models/lucidrains/dalle/attention.py b/codes/models/lucidrains/dalle/attention.py index 3662dbcb..09ec3978 100644 --- a/codes/models/lucidrains/dalle/attention.py +++ b/codes/models/lucidrains/dalle/attention.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from rotary_embedding_torch import apply_rotary_emb -import bitsandbytes as bnb +import torch_intermediary as ml # helpers @@ -48,9 +48,9 @@ class Attention(nn.Module): self.stable = stable self.causal = causal - self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) + self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( - bnb.nn.Linear8bitLt(inner_dim, dim), + ml.Linear(inner_dim, dim), nn.Dropout(dropout) ) @@ -103,10 +103,10 @@ class SparseConvCausalAttention(nn.Module): self.stable = stable - self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) + self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( - bnb.nn.Linear8bitLt(inner_dim, dim), + ml.Linear(inner_dim, dim), nn.Dropout(dropout) ) @@ -223,10 +223,10 @@ class SparseAxialCausalAttention(nn.Module): self.stable = stable - self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) + self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( - bnb.nn.Linear8bitLt(inner_dim, dim), + ml.Linear(inner_dim, dim), nn.Dropout(dropout) ) diff --git a/codes/models/lucidrains/dalle/transformer.py b/codes/models/lucidrains/dalle/transformer.py index e3ad4769..357ce1ad 100644 --- a/codes/models/lucidrains/dalle/transformer.py +++ b/codes/models/lucidrains/dalle/transformer.py @@ -11,7 +11,7 @@ from models.lucidrains.dalle.attention import Attention, SparseAttention, Sparse from rotary_embedding_torch import RotaryEmbedding, broadcat from g_mlp_pytorch import gMLPBlock -import bitsandbytes as bnb +import torch_intermediary as ml # helpers @@ -79,10 +79,10 @@ class FeedForward(nn.Module): def __init__(self, dim, dropout = 0., mult = 4.): super().__init__() self.net = nn.Sequential( - bnb.nn.Linear8bitLt(dim, dim * mult * 2), + ml.Linear(dim, dim * mult * 2), GEGLU(), nn.Dropout(dropout), - bnb.nn.Linear8bitLt(dim * mult, dim) + ml.Linear(dim * mult, dim) ) def forward(self, x): diff --git a/codes/models/lucidrains/performer/performer_pytorch.py b/codes/models/lucidrains/performer/performer_pytorch.py index db46663d..98ce769f 100644 --- a/codes/models/lucidrains/performer/performer_pytorch.py +++ b/codes/models/lucidrains/performer/performer_pytorch.py @@ -21,7 +21,7 @@ try: APEX_AVAILABLE = True except: APEX_AVAILABLE = False -import bitsandbytes as bnb +import torch_intermediary as ml # helpers @@ -357,10 +357,10 @@ class FeedForward(nn.Module): activation = default(activation, nn.GELU) self.glu = glu - self.w1 = bnb.nn.Linear8bitLt(dim, dim * mult * (2 if glu else 1)) + self.w1 = ml.Linear(dim, dim * mult * (2 if glu else 1)) self.act = activation() self.dropout = nn.Dropout(dropout) - self.w2 = bnb.nn.Linear8bitLt(dim * mult, dim) + self.w2 = ml.Linear(dim * mult, dim) def forward(self, x, **kwargs): if not self.glu: @@ -402,10 +402,10 @@ class Attention(nn.Module): self.global_heads = heads - local_heads self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None - self.to_q = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) - self.to_k = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) - self.to_v = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) - self.to_out = bnb.nn.Linear8bitLt(inner_dim, dim, bias = attn_out_bias) + self.to_q = ml.Linear(dim, inner_dim, bias = qkv_bias) + self.to_k = ml.Linear(dim, inner_dim, bias = qkv_bias) + self.to_v = ml.Linear(dim, inner_dim, bias = qkv_bias) + self.to_out = ml.Linear(inner_dim, dim, bias = attn_out_bias) self.dropout = nn.Dropout(dropout) def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs): @@ -460,7 +460,7 @@ class AbsolutePositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() # nn.Embedding - self.emb = bnb.nn.StableEmbedding(max_seq_len, dim) + self.emb = ml.Embedding(max_seq_len, dim) def forward(self, x): t = torch.arange(x.shape[1], device=x.device) @@ -622,7 +622,7 @@ class PerformerLM(nn.Module): self.max_seq_len = max_seq_len # nn.Embedding - self.token_emb = bnb.nn.StableEmbedding(num_tokens, dim) + self.token_emb = ml.Embedding(num_tokens, dim) if rotary_position_emb: self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len) @@ -639,7 +639,7 @@ class PerformerLM(nn.Module): self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens) self.norm = nn.LayerNorm(dim) - self.to_out = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embed else None + self.to_out = ml.Linear(dim, num_tokens) if not tie_embed else None def check_redraw_projections(self): self.performer.check_redraw_projections() diff --git a/codes/models/lucidrains/vq.py b/codes/models/lucidrains/vq.py index 058e47fb..13a4b2ae 100644 --- a/codes/models/lucidrains/vq.py +++ b/codes/models/lucidrains/vq.py @@ -8,7 +8,7 @@ from torch.cuda.amp import autocast from einops import rearrange, repeat from contextlib import contextmanager -import bitsandbytes as bnb +import torch_intermediary as ml def par(t, nm): @@ -356,9 +356,9 @@ class VectorQuantize(nn.Module): codebook_dim = default(codebook_dim, dim) requires_projection = codebook_dim != dim - self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection \ + self.project_in = ml.Linear(dim, codebook_dim) if requires_projection \ else nn.Identity() - self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection \ + self.project_out = ml.Linear(codebook_dim, dim) if requires_projection \ else nn.Identity() self.eps = eps diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py index d51618a2..a49af93d 100644 --- a/codes/models/lucidrains/x_transformers.py +++ b/codes/models/lucidrains/x_transformers.py @@ -11,7 +11,7 @@ from einops import rearrange, repeat, reduce from einops.layers.torch import Rearrange from torch.utils.checkpoint import checkpoint -import bitsandbytes as bnb +import torch_intermediary as ml DEFAULT_DIM_HEAD = 64 @@ -127,7 +127,7 @@ class AbsolutePositionalEmbedding(nn.Module): super().__init__() self.scale = dim ** -0.5 # nn.Embedding - self.emb = bnb.nn.StableEmbedding(max_seq_len, dim) + self.emb = ml.Embedding(max_seq_len, dim) def forward(self, x): n = torch.arange(x.shape[1], device=x.device) @@ -157,7 +157,7 @@ class RelativePositionBias(nn.Module): self.num_buckets = num_buckets self.max_distance = max_distance # nn.Embedding - self.relative_attention_bias = bnb.nn.StableEmbedding(num_buckets, heads) + self.relative_attention_bias = ml.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): @@ -363,7 +363,7 @@ class RMSScaleShiftNorm(nn.Module): self.cdim = 1 self.pdim = -1 else: - self.scale_shift_process = bnb.nn.Linear8bitLt(embed_dim, dim * 2, bias=bias) + self.scale_shift_process = ml.Linear(embed_dim, dim * 2, bias=bias) self.cdim = -1 self.pdim = 1 @@ -450,7 +450,7 @@ class GLU(nn.Module): def __init__(self, dim_in, dim_out, activation): super().__init__() self.act = activation - self.proj = bnb.nn.Linear8bitLt(dim_in, dim_out * 2) + self.proj = ml.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -475,7 +475,7 @@ class FeedForward(nn.Module): activation = ReluSquared() if relu_squared else nn.GELU() project_in = nn.Sequential( - bnb.nn.Linear8bitLt(dim, inner_dim), + ml.Linear(dim, inner_dim), activation ) if not glu else GLU(dim, inner_dim, activation) @@ -483,7 +483,7 @@ class FeedForward(nn.Module): project_in, nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), nn.Dropout(dropout), - bnb.nn.Linear8bitLt(inner_dim, dim_out) + ml.Linear(inner_dim, dim_out) ) # init last linear layer to 0 @@ -538,16 +538,16 @@ class Attention(nn.Module): qk_dim = int(collab_compression * qk_dim) self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim)) - self.to_q = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False) - self.to_k = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False) - self.to_v = bnb.nn.Linear8bitLt(dim, v_dim, bias=False) + self.to_q = ml.Linear(dim, qk_dim, bias=False) + self.to_k = ml.Linear(dim, qk_dim, bias=False) + self.to_v = ml.Linear(dim, v_dim, bias=False) self.dropout = nn.Dropout(dropout) # add GLU gating for aggregated values, from alphafold2 self.to_v_gate = None if gate_values: - self.to_v_gate = bnb.nn.Linear8bitLt(dim, v_dim) + self.to_v_gate = ml.Linear(dim, v_dim) nn.init.constant_(self.to_v_gate.weight, 0) nn.init.constant_(self.to_v_gate.bias, 1) @@ -584,7 +584,7 @@ class Attention(nn.Module): # attention on attention self.attn_on_attn = on_attn out_dim = default(out_dim, dim) - self.to_out = nn.Sequential(bnb.nn.Linear8bitLt(v_dim, out_dim * 2), nn.GLU()) if on_attn else bnb.nn.Linear8bitLt(v_dim, out_dim) + self.to_out = nn.Sequential(ml.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, out_dim) self.rel_pos_bias = rel_pos_bias if rel_pos_bias: @@ -1080,7 +1080,7 @@ class ViTransformerWrapper(nn.Module): self.patch_size = patch_size self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) - self.patch_to_embedding = bnb.nn.Linear8bitLt(patch_dim, dim) + self.patch_to_embedding = ml.Linear(patch_dim, dim) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.dropout = nn.Dropout(emb_dropout) @@ -1139,18 +1139,18 @@ class TransformerWrapper(nn.Module): self.shift_mem_down = shift_mem_down # nn.Embedding - self.token_emb = bnb.nn.StableEmbedding(num_tokens, emb_dim) + self.token_emb = ml.Embedding(num_tokens, emb_dim) self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) - self.project_emb = bnb.nn.Linear8bitLt(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.init_() - self.to_logits = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = default(num_memory_tokens, 0) @@ -1237,12 +1237,12 @@ class ContinuousTransformerWrapper(nn.Module): use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) - self.project_in = bnb.nn.Linear8bitLt(dim_in, dim) if exists(dim_in) else nn.Identity() + self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) - self.project_out = bnb.nn.Linear8bitLt(dim, dim_out) if exists(dim_out) else nn.Identity() + self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() def forward( self, diff --git a/codes/models/vqvae/gumbel_quantizer.py b/codes/models/vqvae/gumbel_quantizer.py index 8b37d1ef..e1d95f26 100644 --- a/codes/models/vqvae/gumbel_quantizer.py +++ b/codes/models/vqvae/gumbel_quantizer.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from torch import einsum from utils.weight_scheduler import LinearDecayWeightScheduler -import bitsandbytes as bnb +import torch_intermediary as ml class GumbelQuantizer(nn.Module): @@ -12,7 +12,7 @@ class GumbelQuantizer(nn.Module): super().__init__() self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1) # nn.Embedding - self.codebook = bnb.nn.StableEmbedding(num_tokens, codebook_dim) + self.codebook = ml.Embedding(num_tokens, codebook_dim) self.straight_through = straight_through self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000) self.step = 0 diff --git a/codes/models/vqvae/vector_quantizer.py b/codes/models/vqvae/vector_quantizer.py index 1932422b..6d057c25 100644 --- a/codes/models/vqvae/vector_quantizer.py +++ b/codes/models/vqvae/vector_quantizer.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from models.arch_util import l2norm, sample_vectors, default, ema_inplace -import bitsandbytes as bnb +import torch_intermediary as ml def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): @@ -185,8 +185,8 @@ class VectorQuantize(nn.Module): codebook_dim = default(codebook_dim, dim) requires_projection = codebook_dim != dim - self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection else nn.Identity() - self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection else nn.Identity() + self.project_in = ml.Linear(dim, codebook_dim) if requires_projection else nn.Identity() + self.project_out = ml.Linear(codebook_dim, dim) if requires_projection else nn.Identity() self.eps = eps diff --git a/codes/torch_intermediary/__init__.py b/codes/torch_intermediary/__init__.py new file mode 100644 index 00000000..bc8933a6 --- /dev/null +++ b/codes/torch_intermediary/__init__.py @@ -0,0 +1,41 @@ +""" +from bitsandbytes.nn import Linear8bitLt as Linear +from bitsandbytes.nn import StableEmbedding as Embedding +from bitsandbytes.optim.adam import Adam8bit as Adam +from bitsandbytes.optim.adamw import AdamW8bit as AdamW +""" +""" +from torch.nn import Linear +from torch.nn import Embedding +from torch.optim.adam import Adam +from torch.optim.adamw import AdamW +""" + +OVERRIDE_LINEAR = False +OVERRIDE_EMBEDDING = False +OVERRIDE_ADAM = True +OVERRIDE_ADAMW = True +USE_STABLE_EMBEDDING = True + +if OVERRIDE_LINEAR: + from bitsandbytes.nn import Linear8bitLt as Linear +else: + from torch.nn import Linear + +if OVERRIDE_EMBEDDING: + if USE_STABLE_EMBEDDING: + from bitsandbytes.nn import StableEmbedding as Embedding + else: + from bitsandbytes.nn import Embedding as Embedding +else: + from torch.nn import Embedding + +if OVERRIDE_ADAM: + from bitsandbytes.optim.adam import Adam8bit as Adam +else: + from torch.optim.adam import Adam + +if OVERRIDE_ADAMW: + from bitsandbytes.optim.adamw import AdamW8bit as AdamW +else: + from torch.optim.adamw import AdamW \ No newline at end of file diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 8c5b43ce..0c80ede7 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -21,7 +21,7 @@ import torchvision.utils as utils from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator from utils.util import opt_get, denormalize -import bitsandbytes as bnb +import torch_intermediary as ml logger = logging.getLogger('base') @@ -338,7 +338,7 @@ class ExtensibleTrainer(BaseModel): for net in self.networks.values(): for mod in net.modules(): fan_in = -1 - if isinstance(mod, bnb.nn.Linear8bitLt): + if isinstance(mod, ml.Linear): fan_in = mod.weight.data.shape[1] elif isinstance(mod, nn.Conv1d): fan_in = mod.weight.data.shape[0] diff --git a/codes/trainer/feature_model.py b/codes/trainer/feature_model.py index ea7befd5..c42c9563 100644 --- a/codes/trainer/feature_model.py +++ b/codes/trainer/feature_model.py @@ -7,7 +7,7 @@ import torch.nn as nn import trainer.networks as networks import trainer.lr_scheduler as lr_scheduler from .base_model import BaseModel -import bitsandbytes as bnb +import torch_intermediary as ml logger = logging.getLogger('base') @@ -43,7 +43,7 @@ class FeatureModel(BaseModel): if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) # torch.optim.Adam - self.optimizer_G = bnb.optim.Adam8bit(optim_params, lr=train_opt['lr_G'], + self.optimizer_G = ml.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) diff --git a/codes/trainer/lr_scheduler.py b/codes/trainer/lr_scheduler.py index b22e0d91..ca867e6b 100644 --- a/codes/trainer/lr_scheduler.py +++ b/codes/trainer/lr_scheduler.py @@ -3,7 +3,7 @@ from collections import Counter from collections import defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler -import bitsandbytes as bnb +import torch_intermediary as ml from utils.util import opt_get @@ -137,7 +137,7 @@ class CosineAnnealingLR_Restart(_LRScheduler): if __name__ == "__main__": #torch.optim.Adam - optimizer = bnb.optim.Adam8bit([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0, + optimizer = ml.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0, betas=(0.9, 0.99)) ############################## # MultiStepLR_Restart diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 37cac2f3..2bf75617 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -12,7 +12,7 @@ from utils.util import recursively_detach, opt_get, clip_grad_norm logger = logging.getLogger('base') -import bitsandbytes as bnb +import torch_intermediary as ml # Defines the expected API for a single training step class ConfigurableStep(Module): @@ -84,7 +84,7 @@ class ConfigurableStep(Module): norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d, nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm) # nn.Embedding - emb_modules = (bnb.nn.StableEmbedding, nn.EmbeddingBag) + emb_modules = (ml.Embedding, nn.EmbeddingBag) param_names_notweights = set() all_param_names = set() param_map = {} @@ -126,7 +126,7 @@ class ConfigurableStep(Module): { 'params': params_notweights, 'weight_decay': 0 } ] # torch.optim.AdamW - opt = bnb.optim.AdamW8bit(groups, lr=opt_config['lr'], + opt = ml.AdamW(groups, lr=opt_config['lr'], weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2), betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) opt._group_names = [params_names_weights, params_names_notweights] @@ -145,7 +145,7 @@ class ConfigurableStep(Module): # parameters and just use a normal AdamW implementation. In a large network, these weights will normally # be a tiny fraction of the total weights. # torch.optim.AdamW - opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0, + opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0, betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) opt_unweighted._config = opt_config opt_unweighted._config['network'] = net_name @@ -153,7 +153,7 @@ class ConfigurableStep(Module): self.optimizers.append(opt_unweighted) # torch.optim.AdamW - opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=bnb.optim.AdamW8bit, lr=opt_config['lr'], + opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=ml.AdamW, lr=opt_config['lr'], weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2), betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) opt.param_groups[0]['initial_lr'] = opt_config['lr'] @@ -168,7 +168,7 @@ class ConfigurableStep(Module): elif self.step_opt['optimizer'] == 'lamb': from trainer.optimizers.lamb import Lamb # torch.optim.AdamW - opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0, + opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0, betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) opt_unweighted._config = opt_config opt_unweighted._config['network'] = net_name