From 4427d7fb84c6087627ac8b30038f0e0749e08944 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 22 Feb 2023 23:07:05 +0000 Subject: [PATCH] initial conversion (errors out) --- codes/models/arch_util.py | 7 ++-- codes/models/audio/asr/w2v_wrapper.py | 3 +- codes/models/audio/audio_resnet.py | 3 +- codes/models/audio/mel2vec.py | 23 +++++------ codes/models/audio/music/cheater_gen_ar.py | 6 ++- codes/models/audio/music/diffwave.py | 3 +- codes/models/audio/music/flat_diffusion.py | 7 +++- codes/models/audio/music/gpt_music.py | 11 ++++-- codes/models/audio/music/gpt_music2.py | 6 ++- .../audio/music/instrument_quantizer.py | 7 ++-- codes/models/audio/music/mel2vec_codes_gpt.py | 6 ++- codes/models/audio/music/music_quantizer.py | 3 +- codes/models/audio/music/music_quantizer2.py | 3 +- codes/models/audio/music/tfdpc_v5.py | 8 ++-- .../audio/music/transformer_diffusion12.py | 10 +++-- .../audio/music/transformer_diffusion13.py | 9 +++-- .../audio/music/unet_diffusion_music_codes.py | 8 ++-- codes/models/audio/tts/ctc_code_generator.py | 14 ++++--- codes/models/audio/tts/diffusion_encoder.py | 4 +- codes/models/audio/tts/mini_encoder.py | 5 ++- .../audio/tts/random_latent_converter.py | 3 +- codes/models/audio/tts/tacotron2/layers.py | 3 +- codes/models/audio/tts/tacotron2/tacotron2.py | 4 +- .../audio/tts/tacotron2/wave_tacotron.py | 4 +- .../models/audio/tts/transformer_builders.py | 4 +- .../audio/tts/transformer_diffusion_tts.py | 14 ++++--- .../audio/tts/transformer_diffusion_tts2.py | 18 +++++---- codes/models/audio/tts/unet_diffusion_tts7.py | 7 +++- codes/models/audio/tts/unet_diffusion_tts9.py | 4 +- .../audio/tts/unet_diffusion_tts_flat.py | 5 ++- codes/models/audio/tts/unified_voice2.py | 12 ++++-- codes/models/audio/tts/unified_voice3.py | 14 ++++--- codes/models/audio/tts/unified_voice4.py | 13 ++++--- codes/models/audio/tts/voice_voice_clip.py | 3 +- codes/models/audio/tts/w2v_matcher.py | 8 ++-- codes/models/classifiers/cifar_resnet.py | 3 +- .../classifiers/resnet_with_checkpointing.py | 3 +- codes/models/classifiers/twin_cifar_resnet.py | 3 +- .../classifiers/weighted_conv_resnet.py | 3 +- codes/models/classifiers/wide_kernel_vgg.py | 6 +-- codes/models/clip/clvp.py | 11 ++++-- codes/models/clip/contrastive_audio.py | 8 ++-- codes/models/clip/cvvp.py | 8 ++-- codes/models/clip/mel_text_clip.py | 14 ++++--- codes/models/clip/text_cond_clip.py | 3 +- codes/models/clip/text_voice_clip.py | 17 ++++++--- codes/models/diffusion/nn.py | 3 +- codes/models/diffusion/rrdb_diffusion.py | 7 ++-- codes/models/diffusion/unet_diffusion.py | 12 +++--- codes/models/diffusion/unet_latent_guide.py | 6 ++- .../discriminator_vgg_arch.py | 13 ++++--- .../image_generation/srflow/module_util.py | 3 +- .../stylegan/stylegan2_lucidrains.py | 27 ++++++------- .../image_latents/byol/byol_model_wrapper.py | 9 +++-- .../fixup_resnet/DiscriminatorResnet_arch.py | 7 ++-- codes/models/image_latents/vit_latent.py | 5 ++- codes/models/lucidrains/dalle/attention.py | 13 ++++--- codes/models/lucidrains/dalle/transformer.py | 5 ++- .../lucidrains/performer/performer_pytorch.py | 21 +++++----- codes/models/lucidrains/vq.py | 5 ++- codes/models/lucidrains/x_transformers.py | 38 ++++++++++--------- codes/models/vqvae/gumbel_quantizer.py | 4 +- codes/models/vqvae/vector_quantizer.py | 5 ++- codes/trainer/ExtensibleTrainer.py | 3 +- codes/trainer/feature_model.py | 5 ++- codes/trainer/lr_scheduler.py | 5 ++- codes/trainer/steps.py | 16 +++++--- 67 files changed, 342 insertions(+), 211 deletions(-) diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index 9383ac7b..48041acd 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -9,6 +9,7 @@ import torch.nn.utils.spectral_norm as SpectralNorm from math import sqrt from utils.util import checkpoint +import bitsandbytes as bnb def exists(val): @@ -73,7 +74,7 @@ def initialize_weights(net_l, scale=1): m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() - elif isinstance(m, nn.Linear): + elif isinstance(m, bnb.nn.Linear8bitLt): init.kaiming_normal_(m.weight, a=0, mode='fan_in') m.weight.data *= scale if m.bias is not None: @@ -108,7 +109,7 @@ def default_init_weights(module, scale=1): if isinstance(m, nn.Conv2d): kaiming_init(m, a=0, mode='fan_in', bias=0) m.weight.data *= scale - elif isinstance(m, nn.Linear): + elif isinstance(m, bnb.nn.Linear8bitLt): kaiming_init(m, a=0, mode='fan_in', bias=0) m.weight.data *= scale @@ -141,7 +142,7 @@ def linear(*args, **kwargs): """ Create a linear module. """ - return nn.Linear(*args, **kwargs) + return bnb.nn.Linear8bitLt(*args, **kwargs) def avg_pool_nd(dims, *args, **kwargs): diff --git a/codes/models/audio/asr/w2v_wrapper.py b/codes/models/audio/asr/w2v_wrapper.py index 3d578603..2f1ccdde 100644 --- a/codes/models/audio/asr/w2v_wrapper.py +++ b/codes/models/audio/asr/w2v_wrapper.py @@ -9,6 +9,7 @@ from data.audio.unsupervised_audio_dataset import load_audio from models.audio.tts.tacotron2.text import sequence_to_text from trainer.networks import register_model from utils.util import opt_get +import bitsandbytes as bnb def only_letters(string): @@ -51,7 +52,7 @@ class Wav2VecWrapper(nn.Module): self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model) # Perform some surgery to get the model we actually want. self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled - self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size) + self.w2v.lm_head = bnb.nn.Linear8bitLt(self.w2v.config.hidden_size, vocab_size) self.w2v.config.vocab_size = vocab_size self.w2v.config.pad_token_id = 0 self.w2v.config.ctc_loss_reduction = 'sum' diff --git a/codes/models/audio/audio_resnet.py b/codes/models/audio/audio_resnet.py index 0d3c32e7..9c4d63df 100644 --- a/codes/models/audio/audio_resnet.py +++ b/codes/models/audio/audio_resnet.py @@ -5,6 +5,7 @@ import torch.nn as nn from trainer.networks import register_model from utils.util import opt_get from typing import Type, Any, Callable, Union, List, Optional +import bitsandbytes as bnb __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', @@ -172,7 +173,7 @@ class ResNet(nn.Module): self.layer4 = self._make_layer(block, 512, layers[3], stride=4, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool1d(1) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv1d): diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 568f0099..bac82b8f 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -15,13 +15,14 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled from models.arch_util import ResBlock from trainer.networks import register_model from utils.util import checkpoint +import bitsandbytes as bnb class Mel2Vec2FeatureProjection(nn.Module): def __init__(self, inner_dim, dropout): super().__init__() self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5) - self.projection = nn.Linear(inner_dim, inner_dim) + self.projection = bnb.nn.Linear8bitLt(inner_dim, inner_dim) self.dropout = nn.Dropout(dropout) def forward(self, hidden_states): @@ -58,10 +59,10 @@ class Wav2Vec2Attention(nn.Module): self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) + self.v_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) + self.q_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) + self.out_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -182,10 +183,10 @@ class Wav2Vec2FeedForward(nn.Module): super().__init__() self.intermediate_dropout = nn.Dropout(dropout) - self.intermediate_dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_dense = bnb.nn.Linear8bitLt(hidden_size, intermediate_size) self.intermediate_act_fn = F.gelu - self.output_dense = nn.Linear(intermediate_size, hidden_size) + self.output_dense = bnb.nn.Linear8bitLt(intermediate_size, hidden_size) self.output_dropout = nn.Dropout(dropout) def forward(self, hidden_states): @@ -429,7 +430,7 @@ class Mel2Vec(nn.Module): k = math.sqrt(1 / module.projection.in_features) nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) - elif isinstance(module, nn.Linear): + elif isinstance(module, bnb.nn.Linear8bitLt): if self.disable_custom_linear_init: return module.weight.data.normal_(mean=0.0, std=self.linear_init_scale) @@ -510,7 +511,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): self.codevectors = nn.Parameter( torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) ) - self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars) + self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) # can be decayed for training self.temperature = 2 @@ -606,8 +607,8 @@ class ContrastiveTrainingWrapper(nn.Module): self.inp_length_factor = inp_length_multiplier # make sure that project_hid & project_q are initialized like normal linear layers - self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim) - self.project_q = nn.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim) + self.project_hid = bnb.nn.Linear8bitLt(inner_dim, self.quantizer.codevector_dim) + self.project_q = bnb.nn.Linear8bitLt(self.quantizer.codevector_dim, self.quantizer.codevector_dim) self.reconstruction = do_reconstruction_loss if do_reconstruction_loss: diff --git a/codes/models/audio/music/cheater_gen_ar.py b/codes/models/audio/music/cheater_gen_ar.py index 096e1619..9b1c1d9d 100644 --- a/codes/models/audio/music/cheater_gen_ar.py +++ b/codes/models/audio/music/cheater_gen_ar.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F from torch import nn from transformers import GPT2Config, GPT2Model +import bitsandbytes as bnb from models.arch_util import AttentionBlock, ResBlock from models.audio.tts.lucidrains_dvae import DiscreteVAE @@ -55,8 +56,9 @@ class ConditioningAR(nn.Module): self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. - self.embeddings = nn.Embedding(num_vectors, dim) - self.head = nn.Linear(dim, num_vectors) + # nn.Embedding + self.embeddings = bnb.nn.StableEmbedding(num_vectors, dim) + self.head = bnb.nn.Linear8bitLt(dim, num_vectors) def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False): unused_params = [] diff --git a/codes/models/audio/music/diffwave.py b/codes/models/audio/music/diffwave.py index c7e031c5..cf25bc19 100644 --- a/codes/models/audio/music/diffwave.py +++ b/codes/models/audio/music/diffwave.py @@ -17,6 +17,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb from math import sqrt @@ -24,7 +25,7 @@ from torch.utils.checkpoint import checkpoint from trainer.networks import register_model -Linear = nn.Linear +Linear = bnb.nn.Linear8bitLt ConvTranspose2d = nn.ConvTranspose2d diff --git a/codes/models/audio/music/flat_diffusion.py b/codes/models/audio/music/flat_diffusion.py index 9def768a..cbc52a7d 100644 --- a/codes/models/audio/music/flat_diffusion.py +++ b/codes/models/audio/music/flat_diffusion.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast +import bitsandbytes as bnb from models.arch_util import ResBlock from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear @@ -22,7 +23,8 @@ def is_sequence(t): class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) + # nn.Embedding + self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -158,7 +160,8 @@ class FlatDiffusion(nn.Module): # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. if in_groups is None: - self.embeddings = nn.Embedding(token_count, model_channels) + # nn.Embedding + self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels) else: self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) self.latent_conditioner = nn.Sequential( diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py index 0e3bfd02..b2281662 100644 --- a/codes/models/audio/music/gpt_music.py +++ b/codes/models/audio/music/gpt_music.py @@ -2,6 +2,7 @@ import torch from torch import nn import torch.nn.functional as F from transformers import GPT2Config, GPT2Model +import bitsandbytes as bnb from models.arch_util import AttentionBlock, ResBlock from models.audio.music.music_quantizer import MusicQuantizer @@ -136,8 +137,9 @@ class GptMusicLower(nn.Module): self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. - self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) - self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) + # nn.Embedding + self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) + self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)]) def forward(self, mel, conditioning, return_latent=False): unused_params = [] @@ -238,8 +240,9 @@ class GptMusicUpper(nn.Module): self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. - self.embeddings = nn.ModuleList([nn.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)]) - self.heads = nn.ModuleList([nn.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)]) + # nn.Embedding + self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)]) + self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_upper_vectors) for _ in range(num_upper_groups)]) def forward(self, mel, conditioning, return_latent=False): diff --git a/codes/models/audio/music/gpt_music2.py b/codes/models/audio/music/gpt_music2.py index acaad51f..719d6317 100644 --- a/codes/models/audio/music/gpt_music2.py +++ b/codes/models/audio/music/gpt_music2.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F from torch import nn from transformers import GPT2Config, GPT2Model +import bitsandbytes as bnb from models.arch_util import AttentionBlock, ResBlock from models.audio.tts.lucidrains_dvae import DiscreteVAE @@ -73,8 +74,9 @@ class GptMusicLower(nn.Module): self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. - self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) - self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) + # nn.Embedding + self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) + self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)]) def forward(self, mel, return_latent=False): unused_params = [] diff --git a/codes/models/audio/music/instrument_quantizer.py b/codes/models/audio/music/instrument_quantizer.py index 00ea462d..0d3a9c98 100644 --- a/codes/models/audio/music/instrument_quantizer.py +++ b/codes/models/audio/music/instrument_quantizer.py @@ -3,6 +3,7 @@ import functools import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb from models.diffusion.nn import timestep_embedding from models.lucidrains.vq import VectorQuantize @@ -21,8 +22,8 @@ class SelfClassifyingHead(nn.Module): use_rmsnorm=True, ff_glu=True, do_checkpointing=False) self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2, sample_codebook_temp=init_temperature) - self.to_output = nn.Linear(dim, out_dim) - self.to_decoder = nn.Linear(out_dim, dim) + self.to_output = bnb.nn.Linear8bitLt(dim, out_dim) + self.to_decoder = bnb.nn.Linear8bitLt(out_dim, dim) def do_ar_step(self, x, used_codes): h = self.dec(x) @@ -90,7 +91,7 @@ class InstrumentQuantizer(nn.Module): """ super().__init__() self.op_dim = op_dim - self.proj = nn.Linear(op_dim, dim) + self.proj = bnb.nn.Linear8bitLt(op_dim, dim) self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)]) self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp) self.min_gumbel_temperature = min_temp diff --git a/codes/models/audio/music/mel2vec_codes_gpt.py b/codes/models/audio/music/mel2vec_codes_gpt.py index 9cbf9104..74c3df7f 100644 --- a/codes/models/audio/music/mel2vec_codes_gpt.py +++ b/codes/models/audio/music/mel2vec_codes_gpt.py @@ -1,6 +1,7 @@ import torch from torch import nn import torch.nn.functional as F +import bitsandbytes as bnb from transformers import GPT2Config, GPT2Model from trainer.networks import register_model @@ -17,8 +18,9 @@ class Mel2VecCodesGpt(nn.Module): n_inner=dim*2) self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. - self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)]) - self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)]) + # nn.Embedding + self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_vectors, dim//num_groups) for _ in range(num_groups)]) + self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_vectors) for _ in range(num_groups)]) def forward(self, codes): assert codes.shape[-1] == self.num_groups diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py index dd508048..75bf2c2d 100644 --- a/codes/models/audio/music/music_quantizer.py +++ b/codes/models/audio/music/music_quantizer.py @@ -3,6 +3,7 @@ import functools import torch from torch import nn import torch.nn.functional as F +import bitsandbytes as bnb from models.arch_util import zero_module from models.vqvae.vqvae import Quantize @@ -75,7 +76,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): self.codevectors = nn.Parameter( torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) ) - self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars) + self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) # can be decayed for training self.temperature = 2 diff --git a/codes/models/audio/music/music_quantizer2.py b/codes/models/audio/music/music_quantizer2.py index 8fa73c65..585cf67c 100644 --- a/codes/models/audio/music/music_quantizer2.py +++ b/codes/models/audio/music/music_quantizer2.py @@ -3,6 +3,7 @@ import functools import torch from torch import nn import torch.nn.functional as F +import bitsandbytes as bnb from models.arch_util import zero_module from models.vqvae.vqvae import Quantize @@ -87,7 +88,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): self.codevectors = nn.Parameter( torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) ) - self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars) + self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) # can be decayed for training self.temperature = 2 diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index a71650ae..161c1851 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -1,3 +1,4 @@ + import itertools import os import random @@ -7,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F import torchaudio import torchvision +import bitsandbytes as bnb from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import TimestepBlock @@ -54,12 +56,12 @@ class ConcatAttentionBlock(TimestepBlock): self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) if cond_projection: self.tdim = trunk_dim+cond_dim_hidden - self.cond_project = nn.Linear(cond_dim_in, cond_dim_hidden) + self.cond_project = bnb.nn.Linear8bitLt(cond_dim_in, cond_dim_hidden) else: self.tdim = trunk_dim self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv) self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv) - self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False) + self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False) self.out.weight.data.zero_() def forward(self, x, cond, timestep_emb, rotary_emb): @@ -87,7 +89,7 @@ class ConditioningEncoder(nn.Module): self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1) self.time_proj = time_proj if time_proj: - self.time_proj = nn.Linear(time_embed_dim, embedding_dim) + self.time_proj = bnb.nn.Linear8bitLt(time_embed_dim, embedding_dim) self.attn = Encoder( dim=embedding_dim, depth=attn_blocks, diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index 056e6426..f3c60b53 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -4,6 +4,7 @@ from time import time import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb from models.arch_util import ResBlock from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower @@ -27,7 +28,8 @@ def is_sequence(t): class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) + # nn.Embedding + self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -68,7 +70,7 @@ class ConcatAttentionBlock(TimestepBlock): self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout) self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout) - self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False) + self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False) self.out.weight.data.zero_() def forward(self, x, timestep_emb, rotary_emb): @@ -129,7 +131,7 @@ class TransformerDiffusion(nn.Module): ) prenet_heads = prenet_channels//64 - self.input_converter = nn.Linear(input_vec_dim, prenet_channels) + self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, prenet_channels) self.code_converter = Encoder( dim=prenet_channels, depth=prenet_layers, @@ -145,7 +147,7 @@ class TransformerDiffusion(nn.Module): self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = nn.Linear(prenet_channels*2, model_channels) + self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels) self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index e641b3e1..2cecb011 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -5,6 +5,7 @@ from random import randrange import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \ RelativeQKBias @@ -69,13 +70,14 @@ class ConditioningEncoder(nn.Module): super().__init__() attn = [] self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2) - self.resolution_embedding = nn.Embedding(num_resolutions, hidden_dim) + # nn.Embedding + self.resolution_embedding = bnb.nn.StableEmbedding(num_resolutions, hidden_dim) self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start. for a in range(attn_blocks): attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing)) attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing)) self.attn = nn.Sequential(*attn) - self.out = nn.Linear(hidden_dim, out_dim, bias=False) + self.out = bnb.nn.Linear8bitLt(hidden_dim, out_dim, bias=False) self.dim = hidden_dim self.do_checkpointing = do_checkpointing @@ -131,7 +133,8 @@ class TransformerDiffusion(nn.Module): nn.SiLU(), linear(time_embed_dim, time_proj_dim), ) - self.resolution_embed = nn.Embedding(resolution_steps, time_proj_dim) + # nn.Embedding + self.resolution_embed = bnb.nn.StableEmbedding(resolution_steps, time_proj_dim) self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64) self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim)) diff --git a/codes/models/audio/music/unet_diffusion_music_codes.py b/codes/models/audio/music/unet_diffusion_music_codes.py index d0818744..47293066 100644 --- a/codes/models/audio/music/unet_diffusion_music_codes.py +++ b/codes/models/audio/music/unet_diffusion_music_codes.py @@ -8,6 +8,7 @@ import torch as th import torch.nn as nn import torch.nn.functional as F import torchvision # For debugging, not actually used. +import bitsandbytes as bnb from models.audio.music.gpt_music import GptMusicLower from models.audio.music.music_quantizer import MusicQuantizer @@ -490,7 +491,7 @@ class UNetMusicModel(nn.Module): ) if self.ar_prior: - self.ar_input = nn.Linear(input_vec_dim, model_channels) + self.ar_input = bnb.nn.Linear8bitLt(input_vec_dim, model_channels) self.ar_prior_intg = Encoder( dim=model_channels, depth=4, @@ -504,7 +505,7 @@ class UNetMusicModel(nn.Module): ff_mult=1, ) else: - self.input_converter = nn.Linear(input_vec_dim, model_channels) + self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, model_channels) self.code_converter = Encoder( dim=model_channels, depth=4, @@ -521,7 +522,8 @@ class UNetMusicModel(nn.Module): self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1) if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + # nn.Embedding + self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) self.use_raw_y_as_embedding = use_raw_y_as_embedding assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive. diff --git a/codes/models/audio/tts/ctc_code_generator.py b/codes/models/audio/tts/ctc_code_generator.py index 68905115..1b3f63bd 100644 --- a/codes/models/audio/tts/ctc_code_generator.py +++ b/codes/models/audio/tts/ctc_code_generator.py @@ -3,6 +3,7 @@ from random import random import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer from models.lucidrains.x_transformers import Encoder @@ -36,9 +37,12 @@ class CtcCodeGenerator(nn.Module): self.ctc_codes = ctc_codes pred_codes = (max_pad+1)*(max_repeat+1) - self.position_embedding = nn.Embedding(max_length, model_dim) - self.codes_embedding = nn.Embedding(ctc_codes, model_dim) - self.recursive_embedding = nn.Embedding(pred_codes, model_dim) + # nn.Embedding + self.position_embedding = bnb.nn.StableEmbedding(max_length, model_dim) + # nn.Embedding + self.codes_embedding = bnb.nn.StableEmbedding(ctc_codes, model_dim) + # nn.Embedding + self.recursive_embedding = bnb.nn.StableEmbedding(pred_codes, model_dim) self.mask_embedding = nn.Parameter(torch.randn(model_dim)) self.encoder = Encoder( dim=model_dim, @@ -50,8 +54,8 @@ class CtcCodeGenerator(nn.Module): ff_glu=True, rotary_pos_emb=True, ) - self.pred_head = nn.Linear(model_dim, pred_codes) - self.confidence_head = nn.Linear(model_dim, 1) + self.pred_head = bnb.nn.Linear8bitLt(model_dim, pred_codes) + self.confidence_head = bnb.nn.Linear8bitLt(model_dim, 1) def inference(self, codes, pads, repeats): position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device)) diff --git a/codes/models/audio/tts/diffusion_encoder.py b/codes/models/audio/tts/diffusion_encoder.py index 1dc91c0d..29eddccc 100644 --- a/codes/models/audio/tts/diffusion_encoder.py +++ b/codes/models/audio/tts/diffusion_encoder.py @@ -5,6 +5,8 @@ from functools import partial import torch import torch.nn as nn +import bitsandbytes as bnb + from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \ DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \ exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \ @@ -16,7 +18,7 @@ class TimeIntegrationBlock(nn.Module): super().__init__() self.emb_layers = nn.Sequential( nn.SiLU(), - nn.Linear( + bnb.nn.Linear8bitLt( time_emb_dim, 2 * dim ), diff --git a/codes/models/audio/tts/mini_encoder.py b/codes/models/audio/tts/mini_encoder.py index 23283173..e8fe5498 100644 --- a/codes/models/audio/tts/mini_encoder.py +++ b/codes/models/audio/tts/mini_encoder.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import bitsandbytes as bnb from models.diffusion.nn import normalization, conv_nd, zero_module @@ -138,7 +139,7 @@ class AudioMiniEncoderWithClassifierHead(nn.Module): def __init__(self, classes, distribute_zero_label=True, **kwargs): super().__init__() self.enc = AudioMiniEncoder(**kwargs) - self.head = nn.Linear(self.enc.dim, classes) + self.head = bnb.nn.Linear8bitLt(self.enc.dim, classes) self.num_classes = classes self.distribute_zero_label = distribute_zero_label @@ -183,7 +184,7 @@ class QueryProvidedAttentionBlock(nn.Module): ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels self.norm = normalization(channels) - self.q = nn.Linear(channels, channels) + self.q = bnb.nn.Linear8bitLt(channels, channels) self.qnorm = nn.LayerNorm(channels) self.kv = conv_nd(1, channels, channels*2, 1) if use_new_attention_order: diff --git a/codes/models/audio/tts/random_latent_converter.py b/codes/models/audio/tts/random_latent_converter.py index d4b5dd00..c5331786 100644 --- a/codes/models/audio/tts/random_latent_converter.py +++ b/codes/models/audio/tts/random_latent_converter.py @@ -3,6 +3,7 @@ import math import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb from trainer.networks import register_model from utils.util import opt_get @@ -44,7 +45,7 @@ class RandomLatentConverter(nn.Module): def __init__(self, channels): super().__init__() self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)], - nn.Linear(channels, channels)) + bnb.nn.Linear8bitLt(channels, channels)) self.channels = channels def forward(self, ref): diff --git a/codes/models/audio/tts/tacotron2/layers.py b/codes/models/audio/tts/tacotron2/layers.py index 8b69e0df..925ddffd 100644 --- a/codes/models/audio/tts/tacotron2/layers.py +++ b/codes/models/audio/tts/tacotron2/layers.py @@ -3,12 +3,13 @@ from librosa.filters import mel as librosa_mel_fn from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression from models.audio.tts.tacotron2.stft import STFT +import bitsandbytes as bnb class LinearNorm(torch.nn.Module): def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): super(LinearNorm, self).__init__() - self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + self.linear_layer = torch.bnb.nn.Linear8bitLt(in_dim, out_dim, bias=bias) torch.nn.init.xavier_uniform_( self.linear_layer.weight, diff --git a/codes/models/audio/tts/tacotron2/tacotron2.py b/codes/models/audio/tts/tacotron2/tacotron2.py index 73ecc754..13e9bef1 100644 --- a/codes/models/audio/tts/tacotron2/tacotron2.py +++ b/codes/models/audio/tts/tacotron2/tacotron2.py @@ -8,6 +8,7 @@ from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm from models.audio.tts.tacotron2.hparams import create_hparams from trainer.networks import register_model from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths +import bitsandbytes as bnb class LocationLayer(nn.Module): @@ -463,7 +464,8 @@ class Tacotron2(nn.Module): self.fp16_run = hparams.fp16_run self.n_mel_channels = hparams.n_mel_channels self.n_frames_per_step = hparams.n_frames_per_step - self.embedding = nn.Embedding( + # nn.Embedding + self.embedding = bnb.nn.StableEmbedding( hparams.n_symbols, hparams.symbols_embedding_dim) std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) val = sqrt(3.0) * std # uniform bounds for std diff --git a/codes/models/audio/tts/tacotron2/wave_tacotron.py b/codes/models/audio/tts/tacotron2/wave_tacotron.py index 4c02e05e..510f94c6 100644 --- a/codes/models/audio/tts/tacotron2/wave_tacotron.py +++ b/codes/models/audio/tts/tacotron2/wave_tacotron.py @@ -13,6 +13,7 @@ from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder from trainer.networks import register_model from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths from utils.util import checkpoint +import bitsandbytes as bnb @@ -185,7 +186,8 @@ class WaveTacotron2(nn.Module): self.fp16_run = hparams.fp16_run self.n_mel_channels = hparams.n_mel_channels self.n_frames_per_step = hparams.n_frames_per_step - self.embedding = nn.Embedding( + # nn.Embedding + self.embedding = bnb.nn.StableEmbedding( hparams.n_symbols, hparams.symbols_embedding_dim) std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) val = sqrt(3.0) * std # uniform bounds for std diff --git a/codes/models/audio/tts/transformer_builders.py b/codes/models/audio/tts/transformer_builders.py index 238dab66..74e7ceb2 100644 --- a/codes/models/audio/tts/transformer_builders.py +++ b/codes/models/audio/tts/transformer_builders.py @@ -25,6 +25,7 @@ import random from time import time import torch import torch.nn as nn +import bitsandbytes as bnb from tqdm import tqdm @@ -35,7 +36,8 @@ def null_position_embeddings(range, dim): class LearnedPositionEmbeddings(nn.Module): def __init__(self, seq_len, model_dim, init=.02, relative=False): super().__init__() - self.emb = nn.Embedding(seq_len, model_dim) + # nn.Embedding + self.emb = bnb.nn.StableEmbedding(seq_len, model_dim) # Initializing this way is standard for GPT-2 self.emb.weight.data.normal_(mean=0.0, std=init) self.relative = relative diff --git a/codes/models/audio/tts/transformer_diffusion_tts.py b/codes/models/audio/tts/transformer_diffusion_tts.py index 430ebc4a..00b6485e 100644 --- a/codes/models/audio/tts/transformer_diffusion_tts.py +++ b/codes/models/audio/tts/transformer_diffusion_tts.py @@ -7,6 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlo from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding from trainer.networks import register_model from utils.util import checkpoint +import bitsandbytes as bnb def is_latent(t): @@ -19,7 +20,8 @@ def is_sequence(t): class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) + # nn.Embedding + self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -100,15 +102,17 @@ class TransformerDiffusionTTS(nn.Module): ff_glu=True, rotary_pos_emb=True, ) - self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels) - self.type_embedding = nn.Embedding(types, model_channels) + self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, model_channels) + # nn.Embedding + self.type_embedding = bnb.nn.StableEmbedding(types, model_channels) # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. if in_groups is None: - self.embeddings = nn.Embedding(token_count, model_channels) + # nn.Embedding + self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels) else: self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) self.latent_conditioner = nn.Sequential( @@ -140,7 +144,7 @@ class TransformerDiffusionTTS(nn.Module): self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.intg = nn.Linear(model_channels*2, model_channels) + self.intg = bnb.nn.Linear8bitLt(model_channels*2, model_channels) self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( diff --git a/codes/models/audio/tts/transformer_diffusion_tts2.py b/codes/models/audio/tts/transformer_diffusion_tts2.py index a1351539..cfe6b625 100644 --- a/codes/models/audio/tts/transformer_diffusion_tts2.py +++ b/codes/models/audio/tts/transformer_diffusion_tts2.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock @@ -19,7 +20,8 @@ def is_sequence(t): class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() - self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) + # nn.Embedding + self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -40,7 +42,7 @@ class DietAttentionBlock(TimestepBlock): def __init__(self, in_dim, dim, heads, dropout): super().__init__() self.rms_scale_norm = RMSScaleShiftNorm(in_dim) - self.proj = nn.Linear(in_dim, dim) + self.proj = bnb.nn.Linear8bitLt(in_dim, dim) self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) @@ -105,15 +107,17 @@ class TransformerDiffusionTTS(nn.Module): ff_glu=True, rotary_pos_emb=True, ) - self.clvp_encoder = nn.Linear(clvp_in_dim, prenet_channels) - self.type_embedding = nn.Embedding(types, prenet_channels) + self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, prenet_channels) + # nn.Embedding + self.type_embedding = bnb.nn.StableEmbedding(types, prenet_channels) # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. if in_groups is None: - self.embeddings = nn.Embedding(token_count, prenet_channels) + # nn.Embedding + self.embeddings = bnb.nn.StableEmbedding(token_count, prenet_channels) else: self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels) self.latent_conditioner = nn.Sequential( @@ -144,8 +148,8 @@ class TransformerDiffusionTTS(nn.Module): self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.cond_intg = nn.Linear(prenet_channels*4, model_channels) - self.intg = nn.Linear(prenet_channels*2, model_channels) + self.cond_intg = bnb.nn.Linear8bitLt(prenet_channels*4, model_channels) + self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels) self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) diff --git a/codes/models/audio/tts/unet_diffusion_tts7.py b/codes/models/audio/tts/unet_diffusion_tts7.py index a323c9b0..ccfb4735 100644 --- a/codes/models/audio/tts/unet_diffusion_tts7.py +++ b/codes/models/audio/tts/unet_diffusion_tts7.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast +import bitsandbytes as bnb from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ @@ -247,14 +248,16 @@ class DiffusionTts(nn.Module): ) embedding_dim = model_channels * 8 - self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim) + # nn.Embedding + self.code_embedding = bnb.nn.StableEmbedding(num_tokens+1, embedding_dim) self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1, attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1) self.enable_unaligned_inputs = enabled_unaligned_inputs if enabled_unaligned_inputs: - self.unaligned_embedder = nn.Embedding(num_unaligned_tokens, embedding_dim) + # nn.Embedding + self.unaligned_embedder = bnb.nn.StableEmbedding(num_unaligned_tokens, embedding_dim) self.unaligned_encoder = CheckpointedXTransformerEncoder( max_seq_len=-1, use_pos_emb=False, diff --git a/codes/models/audio/tts/unet_diffusion_tts9.py b/codes/models/audio/tts/unet_diffusion_tts9.py index 278fde70..341ff474 100644 --- a/codes/models/audio/tts/unet_diffusion_tts9.py +++ b/codes/models/audio/tts/unet_diffusion_tts9.py @@ -5,6 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import autocast from x_transformers import Encoder +import bitsandbytes as bnb from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ @@ -206,7 +207,8 @@ class DiffusionTts(nn.Module): # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. self.code_converter = nn.Sequential( - nn.Embedding(in_tokens, conditioning_dim), + # nn.Embedding + bnb.nn.StableEmbedding(in_tokens, conditioning_dim), CheckpointedXTransformerEncoder( needs_permute=False, max_seq_len=-1, diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py index 6b034680..9ee14de4 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast +import bitsandbytes as bnb from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy @@ -193,7 +194,9 @@ class DiffusionTtsFlat(nn.Module): # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. - self.code_embedding = nn.Embedding(in_tokens, model_channels) + + # nn.Embedding + self.code_embedding = bnb.nn.StableEmbedding(in_tokens, model_channels) self.code_converter = nn.Sequential( AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py index 435526c6..b3382f3a 100644 --- a/codes/models/audio/tts/unified_voice2.py +++ b/codes/models/audio/tts/unified_voice2.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + from transformers import GPT2Config, GPT2PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.models.gpt2.modeling_gpt2 import GPT2Attention @@ -12,6 +13,7 @@ from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_e from trainer.networks import register_model from utils.util import opt_get +import bitsandbytes as bnb class ResBlock(nn.Module): """ @@ -279,9 +281,11 @@ class UnifiedVoice(nn.Module): self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.average_conditioning_embeddings = average_conditioning_embeddings - self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) + # nn.Embedding + self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens, model_dim) if use_mel_codes_as_input: - self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + # nn.Embedding + self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) else: self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ @@ -294,8 +298,8 @@ class UnifiedVoice(nn.Module): self.text_solo_embedding = 0 self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.number_text_tokens) - self.mel_head = nn.Linear(model_dim, self.number_mel_codes) + self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens) + self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) # Initialize the embeddings per the GPT-2 scheme embeddings = [self.text_embedding] diff --git a/codes/models/audio/tts/unified_voice3.py b/codes/models/audio/tts/unified_voice3.py index dd26c789..1dfce7f6 100644 --- a/codes/models/audio/tts/unified_voice3.py +++ b/codes/models/audio/tts/unified_voice3.py @@ -1,6 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb + from transformers import GPT2Config, GPT2PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.models.gpt2.modeling_gpt2 import GPT2Attention @@ -271,15 +273,17 @@ class UnifiedVoice(nn.Module): self.model_dim = model_dim self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) - self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim) - self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + # nn.Embedding + self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim) + # nn.Embedding + self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) - self.mel_head = nn.Linear(model_dim, self.number_mel_codes) - self.aligned_head = nn.Linear(model_dim, number_aligned_text_codes) + self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1) + self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) + self.aligned_head = bnb.nn.Linear8bitLt(model_dim, number_aligned_text_codes) # Initialize the embeddings per the GPT-2 scheme embeddings = [self.text_embedding, self.mel_embedding] diff --git a/codes/models/audio/tts/unified_voice4.py b/codes/models/audio/tts/unified_voice4.py index a186260d..d76d3cc3 100644 --- a/codes/models/audio/tts/unified_voice4.py +++ b/codes/models/audio/tts/unified_voice4.py @@ -11,6 +11,7 @@ from models.audio.tts.transformer_builders import build_hf_gpt_transformer from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb from trainer.networks import register_model from utils.util import opt_get +import bitsandbytes as bnb class ResBlock(nn.Module): @@ -255,15 +256,17 @@ class UnifiedVoice(nn.Module): self.model_dim = model_dim self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) - self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim) - self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + # nn.Embedding + self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim) + # nn.Embedding + self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) self.final_norm = nn.LayerNorm(model_dim) - self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) - self.mel_head = nn.Linear(model_dim, self.number_mel_codes) - self.alignment_head = nn.Linear(model_dim, 256) + self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1) + self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) + self.alignment_head = bnb.nn.Linear8bitLt(model_dim, 256) if only_alignment_head: for p in self.parameters(): diff --git a/codes/models/audio/tts/voice_voice_clip.py b/codes/models/audio/tts/voice_voice_clip.py index 22fa2d80..11401668 100644 --- a/codes/models/audio/tts/voice_voice_clip.py +++ b/codes/models/audio/tts/voice_voice_clip.py @@ -8,6 +8,7 @@ from models.audio.tts.mini_encoder import AudioMiniEncoder from trainer.injectors.spec_augment import spec_augment from trainer.networks import register_model from utils.util import opt_get +import bitsandbytes as bnb def exists(val): @@ -36,7 +37,7 @@ class VoiceCLIP(nn.Module): self.encoder = AudioMiniEncoder(80, encoder_output) if pretrained_encoder_dict_path is not None: self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path)) - self.to_latent = nn.Linear(encoder_output, dim_latent, bias=False) + self.to_latent = bnb.nn.Linear8bitLt(encoder_output, dim_latent, bias=False) self.temperature = nn.Parameter(torch.tensor(1.)) self.mel_compression_ratio = mel_compression_ratio diff --git a/codes/models/audio/tts/w2v_matcher.py b/codes/models/audio/tts/w2v_matcher.py index a2d261a6..61ac850a 100644 --- a/codes/models/audio/tts/w2v_matcher.py +++ b/codes/models/audio/tts/w2v_matcher.py @@ -7,6 +7,7 @@ from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper from models.audio.tts.mini_encoder import AudioMiniEncoder from trainer.networks import register_model +import bitsandbytes as bnb class CheckpointedLayer(nn.Module): @@ -56,7 +57,8 @@ class Wav2VecMatcher(nn.Module): WAV2VEC_CHANNELS = 1024 self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1, attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) - self.text_embedding = nn.Embedding(num_text_tokens, model_dim) + # nn.Embedding + self.text_embedding = bnb.nn.StableEmbedding(num_text_tokens, model_dim) self.encoder = CheckpointedXTransformer( max_seq_len=-1, use_pos_emb=False, @@ -73,8 +75,8 @@ class Wav2VecMatcher(nn.Module): ) self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim)) self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim)) - self.w2v_query_encoder = nn.Linear(WAV2VEC_CHANNELS, model_dim) - self.w2v_value_encoder = nn.Linear(WAV2VEC_CHANNELS, model_dim) + self.w2v_query_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim) + self.w2v_value_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim) self.decoder = CheckpointedXTransformer( max_seq_len=-1, # Should be unused use_pos_emb=False, diff --git a/codes/models/classifiers/cifar_resnet.py b/codes/models/classifiers/cifar_resnet.py index ceb78064..d4efbdf4 100644 --- a/codes/models/classifiers/cifar_resnet.py +++ b/codes/models/classifiers/cifar_resnet.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn +import bitsandbytes as bnb from trainer.networks import register_model @@ -98,7 +99,7 @@ class ResNet(nn.Module): self.conv4_x = self._make_layer(block, 128, num_block[2], 2) self.conv5_x = self._make_layer(block, 256, num_block[3], 2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(256 * block.expansion, num_classes) + self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): """make resnet layers(by layer i didnt mean this 'layer' was the diff --git a/codes/models/classifiers/resnet_with_checkpointing.py b/codes/models/classifiers/resnet_with_checkpointing.py index 552d1328..526edcd9 100644 --- a/codes/models/classifiers/resnet_with_checkpointing.py +++ b/codes/models/classifiers/resnet_with_checkpointing.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn from torchvision.models.resnet import BasicBlock, Bottleneck import torchvision +import bitsandbytes as bnb __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', @@ -194,5 +195,5 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs): def register_resnet50(opt_net, opt): model = resnet50(pretrained=opt_net['pretrained']) if opt_net['custom_head_logits']: - model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits']) + model.fc = bnb.nn.Linear8bitLt(512 * 4, opt_net['custom_head_logits']) return model diff --git a/codes/models/classifiers/twin_cifar_resnet.py b/codes/models/classifiers/twin_cifar_resnet.py index 6aa1f938..2025f65a 100644 --- a/codes/models/classifiers/twin_cifar_resnet.py +++ b/codes/models/classifiers/twin_cifar_resnet.py @@ -11,6 +11,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import bitsandbytes as bnb from trainer.networks import register_model @@ -101,7 +102,7 @@ class ResNet(nn.Module): self.conv4_x = self._make_layer(block, 128, num_block[2], 2) self.conv5_x = self._make_layer(block, 256, num_block[3], 2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(256 * block.expansion, num_classes) + self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): """make resnet layers(by layer i didnt mean this 'layer' was the diff --git a/codes/models/classifiers/weighted_conv_resnet.py b/codes/models/classifiers/weighted_conv_resnet.py index f69c2cb4..dfbb6724 100644 --- a/codes/models/classifiers/weighted_conv_resnet.py +++ b/codes/models/classifiers/weighted_conv_resnet.py @@ -11,6 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', from models.vqvae.scaled_weight_conv import ScaledWeightConv from trainer.networks import register_model from utils.util import checkpoint +import bitsandbytes as bnb model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', @@ -213,7 +214,7 @@ class ResNet(nn.Module): self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, ScaledWeightConv): diff --git a/codes/models/classifiers/wide_kernel_vgg.py b/codes/models/classifiers/wide_kernel_vgg.py index d4af0ee0..0bb02f7b 100644 --- a/codes/models/classifiers/wide_kernel_vgg.py +++ b/codes/models/classifiers/wide_kernel_vgg.py @@ -3,7 +3,7 @@ import torch.nn as nn from trainer.networks import register_model from utils.util import opt_get - +import bitsandbytes as bnb class WideKernelVgg(nn.Module): def __init__(self, nf=64, num_classes=2): @@ -49,9 +49,9 @@ class WideKernelVgg(nn.Module): nn.ReLU(), nn.MaxPool2d(kernel_size=2), nn.Flatten(), - nn.Linear(nf * 8 * 4 * 2, 100), + bnb.nn.Linear8bitLt(nf * 8 * 4 * 2, 100), nn.ReLU(), - nn.Linear(100, num_classes) + bnb.nn.Linear8bitLt(100, num_classes) ) # These normalization constants should be derived experimentally. diff --git a/codes/models/clip/clvp.py b/codes/models/clip/clvp.py index 7f4a3461..c22da91f 100644 --- a/codes/models/clip/clvp.py +++ b/codes/models/clip/clvp.py @@ -10,6 +10,7 @@ from models.arch_util import AttentionBlock from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from trainer.networks import register_model from utils.util import opt_get, checkpoint +import bitsandbytes as bnb def exists(val): @@ -58,7 +59,8 @@ class CollapsingTransformer(nn.Module): class ConvFormatEmbedding(nn.Module): def __init__(self, *args, **kwargs): super().__init__() - self.emb = nn.Embedding(*args, **kwargs) + # nn.Embedding + self.emb = bnb.nn.StableEmbedding(*args, **kwargs) def forward(self, x): y = self.emb(x) @@ -98,9 +100,10 @@ class CLVP(nn.Module): self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True) self.mask_conditioning_percentage = mask_conditioning_percentage - self.text_emb = nn.Embedding(num_text_tokens, model_dim) + # nn.Embedding + self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, model_dim) self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True) - self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False) + self.to_text_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) self.distributed_collect = distributed_collect if mel_codes is None: @@ -108,7 +111,7 @@ class CLVP(nn.Module): else: self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) - self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False) + self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) def get_grad_norm_parameter_groups(self): return { diff --git a/codes/models/clip/contrastive_audio.py b/codes/models/clip/contrastive_audio.py index 6bd77e55..ce53ddce 100644 --- a/codes/models/clip/contrastive_audio.py +++ b/codes/models/clip/contrastive_audio.py @@ -9,6 +9,7 @@ from models.arch_util import AttentionBlock from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from trainer.networks import register_model from utils.util import opt_get, checkpoint +import bitsandbytes as bnb def exists(val): @@ -178,7 +179,8 @@ class CollapsingTransformer(nn.Module): class ConvFormatEmbedding(nn.Module): def __init__(self, *args, **kwargs): super().__init__() - self.emb = nn.Embedding(*args, **kwargs) + # nn.Embedding + self.emb = bnb.nn.StableEmbedding(*args, **kwargs) def forward(self, x): y = self.emb(x) @@ -203,8 +205,8 @@ class ContrastiveAudio(nn.Module): self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2), nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent) - self.to_latent = nn.Linear(latent_dim, latent_dim, bias=False) - self.to_latent2 = nn.Linear(latent_dim, latent_dim, bias=False) + self.to_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) + self.to_latent2 = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) self.to_latent2.weight.data = self.to_latent.weight.data self.to_latent2.weight.DO_NOT_TRAIN = True diff --git a/codes/models/clip/cvvp.py b/codes/models/clip/cvvp.py index 2ad7eca6..24567265 100644 --- a/codes/models/clip/cvvp.py +++ b/codes/models/clip/cvvp.py @@ -10,6 +10,7 @@ from models.arch_util import AttentionBlock from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from trainer.networks import register_model from utils.util import opt_get, checkpoint +import bitsandbytes as bnb def exists(val): @@ -58,7 +59,8 @@ class CollapsingTransformer(nn.Module): class ConvFormatEmbedding(nn.Module): def __init__(self, *args, **kwargs): super().__init__() - self.emb = nn.Embedding(*args, **kwargs) + # nn.Embedding + self.emb = bnb.nn.StableEmbedding(*args, **kwargs) def forward(self, x): y = self.emb(x) @@ -86,14 +88,14 @@ class CVVP(nn.Module): self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2), nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage) - self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False) + self.to_conditioning_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) if mel_codes is None: self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2) else: self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) - self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False) + self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) def get_grad_norm_parameter_groups(self): return { diff --git a/codes/models/clip/mel_text_clip.py b/codes/models/clip/mel_text_clip.py index a8c1dda5..6d55696a 100644 --- a/codes/models/clip/mel_text_clip.py +++ b/codes/models/clip/mel_text_clip.py @@ -7,6 +7,7 @@ from torch import einsum from models.lucidrains.dalle.transformer import Transformer from trainer.networks import register_model from utils.util import opt_get +import bitsandbytes as bnb def exists(val): @@ -45,17 +46,20 @@ class MelTextCLIP(nn.Module): mel_compression=256, ): super().__init__() - self.text_emb = nn.Embedding(num_text_tokens, dim_text) - self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) + # nn.Embedding + self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text) + # nn.Embedding + self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text) self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads, rotary_emb=False) - self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False) + self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False) self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1) - self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) + # nn.Embedding + self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads, rotary_emb=False) - self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) self.temperature = nn.Parameter(torch.tensor(1.)) self.text_mask_percentage = text_mask_percentage diff --git a/codes/models/clip/text_cond_clip.py b/codes/models/clip/text_cond_clip.py index 39085202..12ebdbc3 100644 --- a/codes/models/clip/text_cond_clip.py +++ b/codes/models/clip/text_cond_clip.py @@ -7,6 +7,7 @@ from models.audio.tts.unified_voice2 import ConditioningEncoder from models.lucidrains.dalle.transformer import Transformer from trainer.networks import register_model from utils.util import opt_get +import bitsandbytes as bnb def exists(val): @@ -45,7 +46,7 @@ class VoiceCondCLIP(nn.Module): self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads, rotary_emb=False) - self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) self.temperature = nn.Parameter(torch.tensor(1.)) self.voice_mask_percentage = voice_mask_percentage diff --git a/codes/models/clip/text_voice_clip.py b/codes/models/clip/text_voice_clip.py index 297e71dc..0fb1670b 100644 --- a/codes/models/clip/text_voice_clip.py +++ b/codes/models/clip/text_voice_clip.py @@ -11,6 +11,7 @@ from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder from models.lucidrains.dalle.transformer import Transformer from trainer.networks import register_model from utils.util import opt_get +import bitsandbytes as bnb def exists(val): @@ -53,11 +54,13 @@ class VoiceCLIP(nn.Module): distributed_collect=False, ): super().__init__() - self.text_emb = nn.Embedding(num_text_tokens, dim_text) - self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False) + # nn.Embedding + self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text) + self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False) - self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech) - self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + # nn.Embedding + self.speech_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) + self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) if use_xformers: self.text_transformer = CheckpointedXTransformerEncoder( @@ -105,8 +108,10 @@ class VoiceCLIP(nn.Module): self.min_mel_size = min_mel_size self.distributed_collect = distributed_collect if not use_xformers: - self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) - self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) + # nn.Embedding + self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text) + # nn.Embedding + self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) def embed_text(self, text): text_mask = torch.ones_like(text.float()).bool() diff --git a/codes/models/diffusion/nn.py b/codes/models/diffusion/nn.py index 50203d75..a40343b7 100644 --- a/codes/models/diffusion/nn.py +++ b/codes/models/diffusion/nn.py @@ -6,6 +6,7 @@ import math import torch as th import torch.nn as nn +import bitsandbytes as bnb # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. @@ -36,7 +37,7 @@ def linear(*args, **kwargs): """ Create a linear module. """ - return nn.Linear(*args, **kwargs) + return bnb.nn.Linear8bitLt(*args, **kwargs) def avg_pool_nd(dims, *args, **kwargs): diff --git a/codes/models/diffusion/rrdb_diffusion.py b/codes/models/diffusion/rrdb_diffusion.py index 9deab2f3..0ca91b5b 100644 --- a/codes/models/diffusion/rrdb_diffusion.py +++ b/codes/models/diffusion/rrdb_diffusion.py @@ -6,6 +6,7 @@ from models.arch_util import ConvGnLelu, default_init_weights, make_layer from models.diffusion.nn import timestep_embedding from trainer.networks import register_model from utils.util import checkpoint +import bitsandbytes as bnb # Conditionally uses torch's checkpoint functionality if it is enabled in the opt file. @@ -28,7 +29,7 @@ class ResidualDenseBlock(nn.Module): self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True) self.emb_layers = nn.Sequential( nn.SiLU(), - nn.Linear( + bnb.nn.Linear8bitLt( mid_channels*4, mid_channels, ), @@ -143,9 +144,9 @@ class RRDBNet(nn.Module): # Guided diffusion uses a time embedding. time_embed_dim = mid_channels * 4 self.time_embed = nn.Sequential( - nn.Linear(mid_channels, time_embed_dim), + bnb.nn.Linear8bitLt(mid_channels, time_embed_dim), nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), + bnb.nn.Linear8bitLt(time_embed_dim, time_embed_dim), ) self.body = make_layer( diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py index e2ed97b7..ea5ec504 100644 --- a/codes/models/diffusion/unet_diffusion.py +++ b/codes/models/diffusion/unet_diffusion.py @@ -20,6 +20,7 @@ from models.diffusion.nn import ( ) from trainer.networks import register_model from utils.util import checkpoint +import bitsandbytes as bnb class AttentionPool2d(nn.Module): @@ -515,7 +516,8 @@ class UNetModel(nn.Module): ) if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + # nn.Embedding + self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) self.use_raw_y_as_embedding = use_raw_y_as_embedding assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive. @@ -867,16 +869,16 @@ class EncoderUNetModel(nn.Module): ) elif pool == "spatial": self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), + bnb.nn.Linear8bitLt(self._feature_size, 2048), nn.ReLU(), - nn.Linear(2048, self.out_channels), + bnb.nn.Linear8bitLt(2048, self.out_channels), ) elif pool == "spatial_v2": self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), + bnb.nn.Linear8bitLt(self._feature_size, 2048), normalization(2048), nn.SiLU(), - nn.Linear(2048, self.out_channels), + bnb.nn.Linear8bitLt(2048, self.out_channels), ) else: raise NotImplementedError(f"Unexpected {pool} pooling") diff --git a/codes/models/diffusion/unet_latent_guide.py b/codes/models/diffusion/unet_latent_guide.py index 41dd85a8..52272141 100644 --- a/codes/models/diffusion/unet_latent_guide.py +++ b/codes/models/diffusion/unet_latent_guide.py @@ -26,6 +26,7 @@ from models.diffusion.nn import ( ) from trainer.networks import register_model from utils.util import checkpoint +import bitsandbytes as bnb class AttentionPool2d(nn.Module): @@ -476,7 +477,8 @@ class UNetModel(nn.Module): ) if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + # nn.Embedding + self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) self.input_blocks = nn.ModuleList( [ @@ -736,7 +738,7 @@ class ResNetEncoder(nn.Module): dilate=replace_stride_with_dilation[2]) f=512 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(f * block.expansion, output_dim) + self.fc = bnb.nn.Linear8bitLt(f * block.expansion, output_dim) for m in self.modules(): if isinstance(m, nn.Conv2d): diff --git a/codes/models/image_generation/discriminator_vgg_arch.py b/codes/models/image_generation/discriminator_vgg_arch.py index 234272de..05ea3e8b 100644 --- a/codes/models/image_generation/discriminator_vgg_arch.py +++ b/codes/models/image_generation/discriminator_vgg_arch.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from trainer.networks import register_model from utils.util import checkpoint, opt_get +import bitsandbytes as bnb class Discriminator_VGG_128(nn.Module): @@ -46,8 +47,8 @@ class Discriminator_VGG_128(nn.Module): input_img_factor = input_img_factor // 2 final_nf = nf * 16 - self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) - self.linear2 = nn.Linear(100, 1) + self.linear1 = bnb.nn.Linear8bitLt(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) + self.linear2 = bnb.nn.Linear8bitLt(100, 1) # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) @@ -129,8 +130,8 @@ class Discriminator_VGG_128_GN(nn.Module): # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) - self.linear2 = nn.Linear(100, 1) + self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) + self.linear2 = bnb.nn.Linear8bitLt(100, 1) def compute_body(self, x): fea = self.lrelu(self.conv0_0(x)) @@ -219,8 +220,8 @@ class DiscriminatorVGG448GN(nn.Module): self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) final_nf = nf * 8 - self.linear1 = nn.Linear(int(final_nf * 7 * 7), 100) - self.linear2 = nn.Linear(100, 1) + self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 7 * 7), 100) + self.linear2 = bnb.nn.Linear8bitLt(100, 1) # Assign all new heads to the new param group.2 for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]: diff --git a/codes/models/image_generation/srflow/module_util.py b/codes/models/image_generation/srflow/module_util.py index ca5d7fa9..032e3af3 100644 --- a/codes/models/image_generation/srflow/module_util.py +++ b/codes/models/image_generation/srflow/module_util.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn import torch.nn.init as init import torch.nn.functional as F +import bitsandbytes as bnb def initialize_weights(net_l, scale=1): @@ -14,7 +15,7 @@ def initialize_weights(net_l, scale=1): m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() - elif isinstance(m, nn.Linear): + elif isinstance(m, bnb.nn.Linear8bitLt): init.kaiming_normal_(m.weight, a=0, mode='fan_in') m.weight.data *= scale if m.bias is not None: diff --git a/codes/models/image_generation/stylegan/stylegan2_lucidrains.py b/codes/models/image_generation/stylegan/stylegan2_lucidrains.py index 0f65a7ef..bb72d90c 100644 --- a/codes/models/image_generation/stylegan/stylegan2_lucidrains.py +++ b/codes/models/image_generation/stylegan/stylegan2_lucidrains.py @@ -28,6 +28,7 @@ except: APEX_AVAILABLE = False assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' +import bitsandbytes as bnb num_cores = multiprocessing.cpu_count() @@ -351,7 +352,7 @@ class RGBBlock(nn.Module): def __init__(self, latent_dim, input_channel, upsample, rgba=False): super().__init__() self.input_channel = input_channel - self.to_style = nn.Linear(latent_dim, input_channel) + self.to_style = bnb.nn.Linear8bitLt(latent_dim, input_channel) out_filters = 3 if not rgba else 4 self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False) @@ -489,16 +490,16 @@ class GeneratorBlockWithStructure(nn.Module): # Uses stylegan1 style blocks for injecting structural latent. self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1) - self.to_noise0 = nn.Linear(1, filters) + self.to_noise0 = bnb.nn.Linear8bitLt(1, filters) self.noise0 = equal_lr(NoiseInjection(filters)) self.adain0 = AdaptiveInstanceNorm(filters, latent_dim) - self.to_style1 = nn.Linear(latent_dim, filters) - self.to_noise1 = nn.Linear(1, filters) + self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, filters) + self.to_noise1 = bnb.nn.Linear8bitLt(1, filters) self.conv1 = Conv2DMod(filters, filters, 3) - self.to_style2 = nn.Linear(latent_dim, filters) - self.to_noise2 = nn.Linear(1, filters) + self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters) + self.to_noise2 = bnb.nn.Linear8bitLt(1, filters) self.conv2 = Conv2DMod(filters, filters, 3) self.activation = leaky_relu() @@ -540,12 +541,12 @@ class GeneratorBlock(nn.Module): self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1) input_channels = input_channels * 2 - self.to_style1 = nn.Linear(latent_dim, input_channels) - self.to_noise1 = nn.Linear(1, filters) + self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, input_channels) + self.to_noise1 = bnb.nn.Linear8bitLt(1, filters) self.conv1 = Conv2DMod(input_channels, filters, 3) - self.to_style2 = nn.Linear(latent_dim, filters) - self.to_noise2 = nn.Linear(1, filters) + self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters) + self.to_noise2 = bnb.nn.Linear8bitLt(1, filters) self.conv2 = Conv2DMod(filters, filters, 3) self.activation = leaky_relu() @@ -724,7 +725,7 @@ class StyleGan2GeneratorWithLatent(nn.Module): def _init_weights(self): for m in self.modules(): - if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'): + if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt} and hasattr(m, 'weight'): nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') for block in self.gen.blocks: @@ -804,7 +805,7 @@ class StyleGan2Discriminator(nn.Module): self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1) self.flatten = Flatten() - self.to_logit = nn.Linear(latent_dim, 1) + self.to_logit = bnb.nn.Linear8bitLt(latent_dim, 1) self._init_weights() @@ -836,7 +837,7 @@ class StyleGan2Discriminator(nn.Module): def _init_weights(self): for m in self.modules(): - if type(m) in {nn.Conv2d, nn.Linear}: + if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt}: nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') diff --git a/codes/models/image_latents/byol/byol_model_wrapper.py b/codes/models/image_latents/byol/byol_model_wrapper.py index cf3b6ea0..105d7ec7 100644 --- a/codes/models/image_latents/byol/byol_model_wrapper.py +++ b/codes/models/image_latents/byol/byol_model_wrapper.py @@ -12,6 +12,7 @@ from torch import nn from data.images.byol_attachment import RandomApply from trainer.networks import register_model, create_model from utils.util import checkpoint, opt_get +import bitsandbytes as bnb def default(val, def_val): @@ -78,10 +79,10 @@ class MLP(nn.Module): def __init__(self, dim, projection_size, hidden_size=4096): super().__init__() self.net = nn.Sequential( - nn.Linear(dim, hidden_size), + bnb.nn.Linear8bitLt(dim, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(inplace=True), - nn.Linear(hidden_size, projection_size) + bnb.nn.Linear8bitLt(hidden_size, projection_size) ) def forward(self, x): @@ -103,10 +104,10 @@ class StructuralMLP(nn.Module): nn.BatchNorm2d(c), nn.ReLU(inplace=True), nn.Flatten(), - nn.Linear(flattened_dim, hidden_size), + bnb.nn.Linear8bitLt(flattened_dim, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(inplace=True), - nn.Linear(hidden_size, projection_size) + bnb.nn.Linear8bitLt(hidden_size, projection_size) ) def forward(self, x): diff --git a/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py b/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py index 6991b0fc..2513ba94 100644 --- a/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py +++ b/codes/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import numpy as np +import bitsandbytes as bnb __all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] @@ -108,8 +109,8 @@ class FixupResNet(nn.Module): self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2) self.bias2 = nn.Parameter(torch.zeros(1)) reduced_img_sz = int(input_img_size / 32) - self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) - self.fc2 = nn.Linear(100, num_classes) + self.fc1 = bnb.nn.Linear8bitLt(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) + self.fc2 = bnb.nn.Linear8bitLt(100, num_classes) for m in self.modules(): if isinstance(m, FixupBasicBlock): @@ -124,7 +125,7 @@ class FixupResNet(nn.Module): if m.downsample is not None: nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) ''' - elif isinstance(m, nn.Linear): + elif isinstance(m, bnb.nn.Linear8bitLt): nn.init.constant_(m.weight, 0) nn.init.constant_(m.bias, 0)''' diff --git a/codes/models/image_latents/vit_latent.py b/codes/models/image_latents/vit_latent.py index 0243e8a7..3f45a5c8 100644 --- a/codes/models/image_latents/vit_latent.py +++ b/codes/models/image_latents/vit_latent.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from models.arch_util import ResBlock from models.lucidrains.x_transformers import Encoder from trainer.networks import register_model +import bitsandbytes as bnb class VitLatent(nn.Module): @@ -31,10 +32,10 @@ class VitLatent(nn.Module): do_checkpointing=True ) - self.mlp = nn.Sequential(nn.Linear(hidden_dim, hidden_dim*2), + self.mlp = nn.Sequential(bnb.nn.Linear8bitLt(hidden_dim, hidden_dim*2), nn.BatchNorm1d(hidden_dim*2), nn.ReLU(inplace=True), - nn.Linear(hidden_dim*2, hidden_dim)) + bnb.nn.Linear8bitLt(hidden_dim*2, hidden_dim)) def provide_ema(self, ema): self.ema = ema diff --git a/codes/models/lucidrains/dalle/attention.py b/codes/models/lucidrains/dalle/attention.py index c3b52cc4..3662dbcb 100644 --- a/codes/models/lucidrains/dalle/attention.py +++ b/codes/models/lucidrains/dalle/attention.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from rotary_embedding_torch import apply_rotary_emb +import bitsandbytes as bnb # helpers @@ -47,9 +48,9 @@ class Attention(nn.Module): self.stable = stable self.causal = causal - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), + bnb.nn.Linear8bitLt(inner_dim, dim), nn.Dropout(dropout) ) @@ -102,10 +103,10 @@ class SparseConvCausalAttention(nn.Module): self.stable = stable - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), + bnb.nn.Linear8bitLt(inner_dim, dim), nn.Dropout(dropout) ) @@ -222,10 +223,10 @@ class SparseAxialCausalAttention(nn.Module): self.stable = stable - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), + bnb.nn.Linear8bitLt(inner_dim, dim), nn.Dropout(dropout) ) diff --git a/codes/models/lucidrains/dalle/transformer.py b/codes/models/lucidrains/dalle/transformer.py index 389e6849..e3ad4769 100644 --- a/codes/models/lucidrains/dalle/transformer.py +++ b/codes/models/lucidrains/dalle/transformer.py @@ -11,6 +11,7 @@ from models.lucidrains.dalle.attention import Attention, SparseAttention, Sparse from rotary_embedding_torch import RotaryEmbedding, broadcat from g_mlp_pytorch import gMLPBlock +import bitsandbytes as bnb # helpers @@ -78,10 +79,10 @@ class FeedForward(nn.Module): def __init__(self, dim, dropout = 0., mult = 4.): super().__init__() self.net = nn.Sequential( - nn.Linear(dim, dim * mult * 2), + bnb.nn.Linear8bitLt(dim, dim * mult * 2), GEGLU(), nn.Dropout(dropout), - nn.Linear(dim * mult, dim) + bnb.nn.Linear8bitLt(dim * mult, dim) ) def forward(self, x): diff --git a/codes/models/lucidrains/performer/performer_pytorch.py b/codes/models/lucidrains/performer/performer_pytorch.py index 2d618f11..db46663d 100644 --- a/codes/models/lucidrains/performer/performer_pytorch.py +++ b/codes/models/lucidrains/performer/performer_pytorch.py @@ -21,6 +21,7 @@ try: APEX_AVAILABLE = True except: APEX_AVAILABLE = False +import bitsandbytes as bnb # helpers @@ -356,10 +357,10 @@ class FeedForward(nn.Module): activation = default(activation, nn.GELU) self.glu = glu - self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1)) + self.w1 = bnb.nn.Linear8bitLt(dim, dim * mult * (2 if glu else 1)) self.act = activation() self.dropout = nn.Dropout(dropout) - self.w2 = nn.Linear(dim * mult, dim) + self.w2 = bnb.nn.Linear8bitLt(dim * mult, dim) def forward(self, x, **kwargs): if not self.glu: @@ -401,10 +402,10 @@ class Attention(nn.Module): self.global_heads = heads - local_heads self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None - self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias) - self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias) - self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias) - self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias) + self.to_q = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) + self.to_k = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) + self.to_v = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) + self.to_out = bnb.nn.Linear8bitLt(inner_dim, dim, bias = attn_out_bias) self.dropout = nn.Dropout(dropout) def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs): @@ -458,7 +459,8 @@ class CrossAttention(Attention): class AbsolutePositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() - self.emb = nn.Embedding(max_seq_len, dim) + # nn.Embedding + self.emb = bnb.nn.StableEmbedding(max_seq_len, dim) def forward(self, x): t = torch.arange(x.shape[1], device=x.device) @@ -619,7 +621,8 @@ class PerformerLM(nn.Module): local_attn_heads = cast_tuple(local_attn_heads) self.max_seq_len = max_seq_len - self.token_emb = nn.Embedding(num_tokens, dim) + # nn.Embedding + self.token_emb = bnb.nn.StableEmbedding(num_tokens, dim) if rotary_position_emb: self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len) @@ -636,7 +639,7 @@ class PerformerLM(nn.Module): self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens) self.norm = nn.LayerNorm(dim) - self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None + self.to_out = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embed else None def check_redraw_projections(self): self.performer.check_redraw_projections() diff --git a/codes/models/lucidrains/vq.py b/codes/models/lucidrains/vq.py index 4b1019ec..058e47fb 100644 --- a/codes/models/lucidrains/vq.py +++ b/codes/models/lucidrains/vq.py @@ -8,6 +8,7 @@ from torch.cuda.amp import autocast from einops import rearrange, repeat from contextlib import contextmanager +import bitsandbytes as bnb def par(t, nm): @@ -355,9 +356,9 @@ class VectorQuantize(nn.Module): codebook_dim = default(codebook_dim, dim) requires_projection = codebook_dim != dim - self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \ + self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection \ else nn.Identity() - self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \ + self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection \ else nn.Identity() self.eps = eps diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py index 32ad19e7..d51618a2 100644 --- a/codes/models/lucidrains/x_transformers.py +++ b/codes/models/lucidrains/x_transformers.py @@ -11,6 +11,7 @@ from einops import rearrange, repeat, reduce from einops.layers.torch import Rearrange from torch.utils.checkpoint import checkpoint +import bitsandbytes as bnb DEFAULT_DIM_HEAD = 64 @@ -125,7 +126,8 @@ class AbsolutePositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() self.scale = dim ** -0.5 - self.emb = nn.Embedding(max_seq_len, dim) + # nn.Embedding + self.emb = bnb.nn.StableEmbedding(max_seq_len, dim) def forward(self, x): n = torch.arange(x.shape[1], device=x.device) @@ -154,7 +156,8 @@ class RelativePositionBias(nn.Module): self.causal = causal self.num_buckets = num_buckets self.max_distance = max_distance - self.relative_attention_bias = nn.Embedding(num_buckets, heads) + # nn.Embedding + self.relative_attention_bias = bnb.nn.StableEmbedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): @@ -360,7 +363,7 @@ class RMSScaleShiftNorm(nn.Module): self.cdim = 1 self.pdim = -1 else: - self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias) + self.scale_shift_process = bnb.nn.Linear8bitLt(embed_dim, dim * 2, bias=bias) self.cdim = -1 self.pdim = 1 @@ -447,7 +450,7 @@ class GLU(nn.Module): def __init__(self, dim_in, dim_out, activation): super().__init__() self.act = activation - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = bnb.nn.Linear8bitLt(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -472,7 +475,7 @@ class FeedForward(nn.Module): activation = ReluSquared() if relu_squared else nn.GELU() project_in = nn.Sequential( - nn.Linear(dim, inner_dim), + bnb.nn.Linear8bitLt(dim, inner_dim), activation ) if not glu else GLU(dim, inner_dim, activation) @@ -480,7 +483,7 @@ class FeedForward(nn.Module): project_in, nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + bnb.nn.Linear8bitLt(inner_dim, dim_out) ) # init last linear layer to 0 @@ -535,16 +538,16 @@ class Attention(nn.Module): qk_dim = int(collab_compression * qk_dim) self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim)) - self.to_q = nn.Linear(dim, qk_dim, bias=False) - self.to_k = nn.Linear(dim, qk_dim, bias=False) - self.to_v = nn.Linear(dim, v_dim, bias=False) + self.to_q = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False) + self.to_k = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False) + self.to_v = bnb.nn.Linear8bitLt(dim, v_dim, bias=False) self.dropout = nn.Dropout(dropout) # add GLU gating for aggregated values, from alphafold2 self.to_v_gate = None if gate_values: - self.to_v_gate = nn.Linear(dim, v_dim) + self.to_v_gate = bnb.nn.Linear8bitLt(dim, v_dim) nn.init.constant_(self.to_v_gate.weight, 0) nn.init.constant_(self.to_v_gate.bias, 1) @@ -581,7 +584,7 @@ class Attention(nn.Module): # attention on attention self.attn_on_attn = on_attn out_dim = default(out_dim, dim) - self.to_out = nn.Sequential(nn.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, out_dim) + self.to_out = nn.Sequential(bnb.nn.Linear8bitLt(v_dim, out_dim * 2), nn.GLU()) if on_attn else bnb.nn.Linear8bitLt(v_dim, out_dim) self.rel_pos_bias = rel_pos_bias if rel_pos_bias: @@ -1077,7 +1080,7 @@ class ViTransformerWrapper(nn.Module): self.patch_size = patch_size self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) - self.patch_to_embedding = nn.Linear(patch_dim, dim) + self.patch_to_embedding = bnb.nn.Linear8bitLt(patch_dim, dim) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.dropout = nn.Dropout(emb_dropout) @@ -1135,18 +1138,19 @@ class TransformerWrapper(nn.Module): self.max_mem_len = max_mem_len self.shift_mem_down = shift_mem_down - self.token_emb = nn.Embedding(num_tokens, emb_dim) + # nn.Embedding + self.token_emb = bnb.nn.StableEmbedding(num_tokens, emb_dim) self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) - self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.project_emb = bnb.nn.Linear8bitLt(emb_dim, dim) if emb_dim != dim else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.init_() - self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + self.to_logits = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = default(num_memory_tokens, 0) @@ -1233,12 +1237,12 @@ class ContinuousTransformerWrapper(nn.Module): use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) - self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() + self.project_in = bnb.nn.Linear8bitLt(dim_in, dim) if exists(dim_in) else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) - self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() + self.project_out = bnb.nn.Linear8bitLt(dim, dim_out) if exists(dim_out) else nn.Identity() def forward( self, diff --git a/codes/models/vqvae/gumbel_quantizer.py b/codes/models/vqvae/gumbel_quantizer.py index 2d71ec7f..8b37d1ef 100644 --- a/codes/models/vqvae/gumbel_quantizer.py +++ b/codes/models/vqvae/gumbel_quantizer.py @@ -4,13 +4,15 @@ import torch.nn.functional as F from torch import einsum from utils.weight_scheduler import LinearDecayWeightScheduler +import bitsandbytes as bnb class GumbelQuantizer(nn.Module): def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False): super().__init__() self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1) - self.codebook = nn.Embedding(num_tokens, codebook_dim) + # nn.Embedding + self.codebook = bnb.nn.StableEmbedding(num_tokens, codebook_dim) self.straight_through = straight_through self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000) self.step = 0 diff --git a/codes/models/vqvae/vector_quantizer.py b/codes/models/vqvae/vector_quantizer.py index 96015b02..1932422b 100644 --- a/codes/models/vqvae/vector_quantizer.py +++ b/codes/models/vqvae/vector_quantizer.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from models.arch_util import l2norm, sample_vectors, default, ema_inplace +import bitsandbytes as bnb def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): @@ -184,8 +185,8 @@ class VectorQuantize(nn.Module): codebook_dim = default(codebook_dim, dim) requires_projection = codebook_dim != dim - self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() - self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() + self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection else nn.Identity() + self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection else nn.Identity() self.eps = eps diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 3ee81074..8c5b43ce 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -21,6 +21,7 @@ import torchvision.utils as utils from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator from utils.util import opt_get, denormalize +import bitsandbytes as bnb logger = logging.getLogger('base') @@ -337,7 +338,7 @@ class ExtensibleTrainer(BaseModel): for net in self.networks.values(): for mod in net.modules(): fan_in = -1 - if isinstance(mod, nn.Linear): + if isinstance(mod, bnb.nn.Linear8bitLt): fan_in = mod.weight.data.shape[1] elif isinstance(mod, nn.Conv1d): fan_in = mod.weight.data.shape[0] diff --git a/codes/trainer/feature_model.py b/codes/trainer/feature_model.py index 12e31cb1..ea7befd5 100644 --- a/codes/trainer/feature_model.py +++ b/codes/trainer/feature_model.py @@ -1,3 +1,4 @@ + import logging from collections import OrderedDict @@ -6,6 +7,7 @@ import torch.nn as nn import trainer.networks as networks import trainer.lr_scheduler as lr_scheduler from .base_model import BaseModel +import bitsandbytes as bnb logger = logging.getLogger('base') @@ -40,7 +42,8 @@ class FeatureModel(BaseModel): else: if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) - self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + # torch.optim.Adam + self.optimizer_G = bnb.optim.Adam8bit(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1_G'], train_opt['beta2_G'])) self.optimizers.append(self.optimizer_G) diff --git a/codes/trainer/lr_scheduler.py b/codes/trainer/lr_scheduler.py index f437ff3c..b22e0d91 100644 --- a/codes/trainer/lr_scheduler.py +++ b/codes/trainer/lr_scheduler.py @@ -3,10 +3,10 @@ from collections import Counter from collections import defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler +import bitsandbytes as bnb from utils.util import opt_get - def get_scheduler_for_name(name, optimizers, scheduler_opt): schedulers = [] for o in optimizers: @@ -136,7 +136,8 @@ class CosineAnnealingLR_Restart(_LRScheduler): if __name__ == "__main__": - optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0, + #torch.optim.Adam + optimizer = bnb.optim.Adam8bit([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0, betas=(0.9, 0.99)) ############################## # MultiStepLR_Restart diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 9ec2f78b..37cac2f3 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -12,6 +12,7 @@ from utils.util import recursively_detach, opt_get, clip_grad_norm logger = logging.getLogger('base') +import bitsandbytes as bnb # Defines the expected API for a single training step class ConfigurableStep(Module): @@ -82,7 +83,8 @@ class ConfigurableStep(Module): import torch.nn as nn norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d, nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm) - emb_modules = (nn.Embedding, nn.EmbeddingBag) + # nn.Embedding + emb_modules = (bnb.nn.StableEmbedding, nn.EmbeddingBag) param_names_notweights = set() all_param_names = set() param_map = {} @@ -123,7 +125,8 @@ class ConfigurableStep(Module): { 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) }, { 'params': params_notweights, 'weight_decay': 0 } ] - opt = torch.optim.AdamW(groups, lr=opt_config['lr'], + # torch.optim.AdamW + opt = bnb.optim.AdamW8bit(groups, lr=opt_config['lr'], weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2), betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) opt._group_names = [params_names_weights, params_names_notweights] @@ -141,14 +144,16 @@ class ConfigurableStep(Module): # The torch ZeRO implementation does not seem to support parameter groups, so do not shard the non-weighted # parameters and just use a normal AdamW implementation. In a large network, these weights will normally # be a tiny fraction of the total weights. - opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0, + # torch.optim.AdamW + opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0, betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) opt_unweighted._config = opt_config opt_unweighted._config['network'] = net_name opt_unweighted._group_names = [] self.optimizers.append(opt_unweighted) - opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'], + # torch.optim.AdamW + opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=bnb.optim.AdamW8bit, lr=opt_config['lr'], weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2), betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) opt.param_groups[0]['initial_lr'] = opt_config['lr'] @@ -162,7 +167,8 @@ class ConfigurableStep(Module): opt._group_names = sorted(list(all_param_names)) elif self.step_opt['optimizer'] == 'lamb': from trainer.optimizers.lamb import Lamb - opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0, + # torch.optim.AdamW + opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0, betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) opt_unweighted._config = opt_config opt_unweighted._config['network'] = net_name