bitsandbytes #2

Merged
mrq merged 3 commits from bitsandbytes into master 2023-02-23 03:16:26 +00:00
67 changed files with 342 additions and 211 deletions
Showing only changes of commit 4427d7fb84 - Show all commits

View File

@ -9,6 +9,7 @@ import torch.nn.utils.spectral_norm as SpectralNorm
from math import sqrt
from utils.util import checkpoint
import bitsandbytes as bnb
def exists(val):
@ -73,7 +74,7 @@ def initialize_weights(net_l, scale=1):
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
elif isinstance(m, bnb.nn.Linear8bitLt):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
@ -108,7 +109,7 @@ def default_init_weights(module, scale=1):
if isinstance(m, nn.Conv2d):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
elif isinstance(m, nn.Linear):
elif isinstance(m, bnb.nn.Linear8bitLt):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
@ -141,7 +142,7 @@ def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
return bnb.nn.Linear8bitLt(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):

View File

@ -9,6 +9,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.tts.tacotron2.text import sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
def only_letters(string):
@ -51,7 +52,7 @@ class Wav2VecWrapper(nn.Module):
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
# Perform some surgery to get the model we actually want.
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
self.w2v.lm_head = bnb.nn.Linear8bitLt(self.w2v.config.hidden_size, vocab_size)
self.w2v.config.vocab_size = vocab_size
self.w2v.config.pad_token_id = 0
self.w2v.config.ctc_loss_reduction = 'sum'

View File

@ -5,6 +5,7 @@ import torch.nn as nn
from trainer.networks import register_model
from utils.util import opt_get
from typing import Type, Any, Callable, Union, List, Optional
import bitsandbytes as bnb
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@ -172,7 +173,7 @@ class ResNet(nn.Module):
self.layer4 = self._make_layer(block, 512, layers[3], stride=4,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv1d):

View File

@ -15,13 +15,14 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from models.arch_util import ResBlock
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
class Mel2Vec2FeatureProjection(nn.Module):
def __init__(self, inner_dim, dropout):
super().__init__()
self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5)
self.projection = nn.Linear(inner_dim, inner_dim)
self.projection = bnb.nn.Linear8bitLt(inner_dim, inner_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, hidden_states):
@ -58,10 +59,10 @@ class Wav2Vec2Attention(nn.Module):
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.k_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
self.v_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
self.q_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
self.out_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@ -182,10 +183,10 @@ class Wav2Vec2FeedForward(nn.Module):
super().__init__()
self.intermediate_dropout = nn.Dropout(dropout)
self.intermediate_dense = nn.Linear(hidden_size, intermediate_size)
self.intermediate_dense = bnb.nn.Linear8bitLt(hidden_size, intermediate_size)
self.intermediate_act_fn = F.gelu
self.output_dense = nn.Linear(intermediate_size, hidden_size)
self.output_dense = bnb.nn.Linear8bitLt(intermediate_size, hidden_size)
self.output_dropout = nn.Dropout(dropout)
def forward(self, hidden_states):
@ -429,7 +430,7 @@ class Mel2Vec(nn.Module):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
elif isinstance(module, bnb.nn.Linear8bitLt):
if self.disable_custom_linear_init:
return
module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
@ -510,7 +511,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
)
self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2
@ -606,8 +607,8 @@ class ContrastiveTrainingWrapper(nn.Module):
self.inp_length_factor = inp_length_multiplier
# make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
self.project_q = nn.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
self.project_hid = bnb.nn.Linear8bitLt(inner_dim, self.quantizer.codevector_dim)
self.project_q = bnb.nn.Linear8bitLt(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
self.reconstruction = do_reconstruction_loss
if do_reconstruction_loss:

View File

@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb
from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
@ -55,8 +56,9 @@ class ConditioningAR(nn.Module):
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
self.embeddings = nn.Embedding(num_vectors, dim)
self.head = nn.Linear(dim, num_vectors)
# nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(num_vectors, dim)
self.head = bnb.nn.Linear8bitLt(dim, num_vectors)
def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
unused_params = []

View File

@ -17,6 +17,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from math import sqrt
@ -24,7 +25,7 @@ from torch.utils.checkpoint import checkpoint
from trainer.networks import register_model
Linear = nn.Linear
Linear = bnb.nn.Linear8bitLt
ConvTranspose2d = nn.ConvTranspose2d

View File

@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import bitsandbytes as bnb
from models.arch_util import ResBlock
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@ -22,7 +23,8 @@ def is_sequence(t):
class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
# nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -158,7 +160,8 @@ class FlatDiffusion(nn.Module):
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
if in_groups is None:
self.embeddings = nn.Embedding(token_count, model_channels)
# nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential(

View File

@ -2,6 +2,7 @@ import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb
from models.arch_util import AttentionBlock, ResBlock
from models.audio.music.music_quantizer import MusicQuantizer
@ -136,8 +137,9 @@ class GptMusicLower(nn.Module):
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
# nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)])
def forward(self, mel, conditioning, return_latent=False):
unused_params = []
@ -238,8 +240,9 @@ class GptMusicUpper(nn.Module):
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
self.embeddings = nn.ModuleList([nn.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
self.heads = nn.ModuleList([nn.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
# nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_upper_vectors) for _ in range(num_upper_groups)])
def forward(self, mel, conditioning, return_latent=False):

View File

@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb
from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
@ -73,8 +74,9 @@ class GptMusicLower(nn.Module):
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
# nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)])
def forward(self, mel, return_latent=False):
unused_params = []

View File

@ -3,6 +3,7 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from models.diffusion.nn import timestep_embedding
from models.lucidrains.vq import VectorQuantize
@ -21,8 +22,8 @@ class SelfClassifyingHead(nn.Module):
use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2,
sample_codebook_temp=init_temperature)
self.to_output = nn.Linear(dim, out_dim)
self.to_decoder = nn.Linear(out_dim, dim)
self.to_output = bnb.nn.Linear8bitLt(dim, out_dim)
self.to_decoder = bnb.nn.Linear8bitLt(out_dim, dim)
def do_ar_step(self, x, used_codes):
h = self.dec(x)
@ -90,7 +91,7 @@ class InstrumentQuantizer(nn.Module):
"""
super().__init__()
self.op_dim = op_dim
self.proj = nn.Linear(op_dim, dim)
self.proj = bnb.nn.Linear8bitLt(op_dim, dim)
self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
self.min_gumbel_temperature = min_temp

View File

@ -1,6 +1,7 @@
import torch
from torch import nn
import torch.nn.functional as F
import bitsandbytes as bnb
from transformers import GPT2Config, GPT2Model
from trainer.networks import register_model
@ -17,8 +18,9 @@ class Mel2VecCodesGpt(nn.Module):
n_inner=dim*2)
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)])
# nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_vectors) for _ in range(num_groups)])
def forward(self, codes):
assert codes.shape[-1] == self.num_groups

View File

@ -3,6 +3,7 @@ import functools
import torch
from torch import nn
import torch.nn.functional as F
import bitsandbytes as bnb
from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize
@ -75,7 +76,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
)
self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2

View File

@ -3,6 +3,7 @@ import functools
import torch
from torch import nn
import torch.nn.functional as F
import bitsandbytes as bnb
from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize
@ -87,7 +88,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
)
self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2

View File

@ -1,3 +1,4 @@
import itertools
import os
import random
@ -7,6 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchvision
import bitsandbytes as bnb
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
@ -54,12 +56,12 @@ class ConcatAttentionBlock(TimestepBlock):
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
if cond_projection:
self.tdim = trunk_dim+cond_dim_hidden
self.cond_project = nn.Linear(cond_dim_in, cond_dim_hidden)
self.cond_project = bnb.nn.Linear8bitLt(cond_dim_in, cond_dim_hidden)
else:
self.tdim = trunk_dim
self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv)
self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv)
self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False)
self.out.weight.data.zero_()
def forward(self, x, cond, timestep_emb, rotary_emb):
@ -87,7 +89,7 @@ class ConditioningEncoder(nn.Module):
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
self.time_proj = time_proj
if time_proj:
self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
self.time_proj = bnb.nn.Linear8bitLt(time_embed_dim, embedding_dim)
self.attn = Encoder(
dim=embedding_dim,
depth=attn_blocks,

View File

@ -4,6 +4,7 @@ from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from models.arch_util import ResBlock
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
@ -27,7 +28,8 @@ def is_sequence(t):
class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
# nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -68,7 +70,7 @@ class ConcatAttentionBlock(TimestepBlock):
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False)
self.out.weight.data.zero_()
def forward(self, x, timestep_emb, rotary_emb):
@ -129,7 +131,7 @@ class TransformerDiffusion(nn.Module):
)
prenet_heads = prenet_channels//64
self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, prenet_channels)
self.code_converter = Encoder(
dim=prenet_channels,
depth=prenet_layers,
@ -145,7 +147,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = nn.Linear(prenet_channels*2, model_channels)
self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(

View File

@ -5,6 +5,7 @@ from random import randrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
RelativeQKBias
@ -69,13 +70,14 @@ class ConditioningEncoder(nn.Module):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
self.resolution_embedding = nn.Embedding(num_resolutions, hidden_dim)
# nn.Embedding
self.resolution_embedding = bnb.nn.StableEmbedding(num_resolutions, hidden_dim)
self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
for a in range(attn_blocks):
attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.out = nn.Linear(hidden_dim, out_dim, bias=False)
self.out = bnb.nn.Linear8bitLt(hidden_dim, out_dim, bias=False)
self.dim = hidden_dim
self.do_checkpointing = do_checkpointing
@ -131,7 +133,8 @@ class TransformerDiffusion(nn.Module):
nn.SiLU(),
linear(time_embed_dim, time_proj_dim),
)
self.resolution_embed = nn.Embedding(resolution_steps, time_proj_dim)
# nn.Embedding
self.resolution_embed = bnb.nn.StableEmbedding(resolution_steps, time_proj_dim)
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))

View File

@ -8,6 +8,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torchvision # For debugging, not actually used.
import bitsandbytes as bnb
from models.audio.music.gpt_music import GptMusicLower
from models.audio.music.music_quantizer import MusicQuantizer
@ -490,7 +491,7 @@ class UNetMusicModel(nn.Module):
)
if self.ar_prior:
self.ar_input = nn.Linear(input_vec_dim, model_channels)
self.ar_input = bnb.nn.Linear8bitLt(input_vec_dim, model_channels)
self.ar_prior_intg = Encoder(
dim=model_channels,
depth=4,
@ -504,7 +505,7 @@ class UNetMusicModel(nn.Module):
ff_mult=1,
)
else:
self.input_converter = nn.Linear(input_vec_dim, model_channels)
self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, model_channels)
self.code_converter = Encoder(
dim=model_channels,
depth=4,
@ -521,7 +522,8 @@ class UNetMusicModel(nn.Module):
self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
# nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.

View File

@ -3,6 +3,7 @@ from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
from models.lucidrains.x_transformers import Encoder
@ -36,9 +37,12 @@ class CtcCodeGenerator(nn.Module):
self.ctc_codes = ctc_codes
pred_codes = (max_pad+1)*(max_repeat+1)
self.position_embedding = nn.Embedding(max_length, model_dim)
self.codes_embedding = nn.Embedding(ctc_codes, model_dim)
self.recursive_embedding = nn.Embedding(pred_codes, model_dim)
# nn.Embedding
self.position_embedding = bnb.nn.StableEmbedding(max_length, model_dim)
# nn.Embedding
self.codes_embedding = bnb.nn.StableEmbedding(ctc_codes, model_dim)
# nn.Embedding
self.recursive_embedding = bnb.nn.StableEmbedding(pred_codes, model_dim)
self.mask_embedding = nn.Parameter(torch.randn(model_dim))
self.encoder = Encoder(
dim=model_dim,
@ -50,8 +54,8 @@ class CtcCodeGenerator(nn.Module):
ff_glu=True,
rotary_pos_emb=True,
)
self.pred_head = nn.Linear(model_dim, pred_codes)
self.confidence_head = nn.Linear(model_dim, 1)
self.pred_head = bnb.nn.Linear8bitLt(model_dim, pred_codes)
self.confidence_head = bnb.nn.Linear8bitLt(model_dim, 1)
def inference(self, codes, pads, repeats):
position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))

View File

@ -5,6 +5,8 @@ from functools import partial
import torch
import torch.nn as nn
import bitsandbytes as bnb
from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
@ -16,7 +18,7 @@ class TimeIntegrationBlock(nn.Module):
super().__init__()
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
bnb.nn.Linear8bitLt(
time_emb_dim,
2 * dim
),

View File

@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import bitsandbytes as bnb
from models.diffusion.nn import normalization, conv_nd, zero_module
@ -138,7 +139,7 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
def __init__(self, classes, distribute_zero_label=True, **kwargs):
super().__init__()
self.enc = AudioMiniEncoder(**kwargs)
self.head = nn.Linear(self.enc.dim, classes)
self.head = bnb.nn.Linear8bitLt(self.enc.dim, classes)
self.num_classes = classes
self.distribute_zero_label = distribute_zero_label
@ -183,7 +184,7 @@ class QueryProvidedAttentionBlock(nn.Module):
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = normalization(channels)
self.q = nn.Linear(channels, channels)
self.q = bnb.nn.Linear8bitLt(channels, channels)
self.qnorm = nn.LayerNorm(channels)
self.kv = conv_nd(1, channels, channels*2, 1)
if use_new_attention_order:

View File

@ -3,6 +3,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from trainer.networks import register_model
from utils.util import opt_get
@ -44,7 +45,7 @@ class RandomLatentConverter(nn.Module):
def __init__(self, channels):
super().__init__()
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
nn.Linear(channels, channels))
bnb.nn.Linear8bitLt(channels, channels))
self.channels = channels
def forward(self, ref):

View File

@ -3,12 +3,13 @@ from librosa.filters import mel as librosa_mel_fn
from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
from models.audio.tts.tacotron2.stft import STFT
import bitsandbytes as bnb
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
self.linear_layer = torch.bnb.nn.Linear8bitLt(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,

View File

@ -8,6 +8,7 @@ from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
from models.audio.tts.tacotron2.hparams import create_hparams
from trainer.networks import register_model
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
import bitsandbytes as bnb
class LocationLayer(nn.Module):
@ -463,7 +464,8 @@ class Tacotron2(nn.Module):
self.fp16_run = hparams.fp16_run
self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step
self.embedding = nn.Embedding(
# nn.Embedding
self.embedding = bnb.nn.StableEmbedding(
hparams.n_symbols, hparams.symbols_embedding_dim)
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std

View File

@ -13,6 +13,7 @@ from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
from trainer.networks import register_model
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import checkpoint
import bitsandbytes as bnb
@ -185,7 +186,8 @@ class WaveTacotron2(nn.Module):
self.fp16_run = hparams.fp16_run
self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step
self.embedding = nn.Embedding(
# nn.Embedding
self.embedding = bnb.nn.StableEmbedding(
hparams.n_symbols, hparams.symbols_embedding_dim)
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std

View File

@ -25,6 +25,7 @@ import random
from time import time
import torch
import torch.nn as nn
import bitsandbytes as bnb
from tqdm import tqdm
@ -35,7 +36,8 @@ def null_position_embeddings(range, dim):
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02, relative=False):
super().__init__()
self.emb = nn.Embedding(seq_len, model_dim)
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
self.relative = relative

View File

@ -7,6 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlo
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
def is_latent(t):
@ -19,7 +20,8 @@ def is_sequence(t):
class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
# nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -100,15 +102,17 @@ class TransformerDiffusionTTS(nn.Module):
ff_glu=True,
rotary_pos_emb=True,
)
self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
self.type_embedding = nn.Embedding(types, model_channels)
self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, model_channels)
# nn.Embedding
self.type_embedding = bnb.nn.StableEmbedding(types, model_channels)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
if in_groups is None:
self.embeddings = nn.Embedding(token_count, model_channels)
# nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential(
@ -140,7 +144,7 @@ class TransformerDiffusionTTS(nn.Module):
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = nn.Linear(model_channels*2, model_channels)
self.intg = bnb.nn.Linear8bitLt(model_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
@ -19,7 +20,8 @@ def is_sequence(t):
class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
# nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -40,7 +42,7 @@ class DietAttentionBlock(TimestepBlock):
def __init__(self, in_dim, dim, heads, dropout):
super().__init__()
self.rms_scale_norm = RMSScaleShiftNorm(in_dim)
self.proj = nn.Linear(in_dim, dim)
self.proj = bnb.nn.Linear8bitLt(in_dim, dim)
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout)
self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
@ -105,15 +107,17 @@ class TransformerDiffusionTTS(nn.Module):
ff_glu=True,
rotary_pos_emb=True,
)
self.clvp_encoder = nn.Linear(clvp_in_dim, prenet_channels)
self.type_embedding = nn.Embedding(types, prenet_channels)
self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, prenet_channels)
# nn.Embedding
self.type_embedding = bnb.nn.StableEmbedding(types, prenet_channels)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
if in_groups is None:
self.embeddings = nn.Embedding(token_count, prenet_channels)
# nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, prenet_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels)
self.latent_conditioner = nn.Sequential(
@ -144,8 +148,8 @@ class TransformerDiffusionTTS(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.cond_intg = nn.Linear(prenet_channels*4, model_channels)
self.intg = nn.Linear(prenet_channels*2, model_channels)
self.cond_intg = bnb.nn.Linear8bitLt(prenet_channels*4, model_channels)
self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import bitsandbytes as bnb
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
@ -247,14 +248,16 @@ class DiffusionTts(nn.Module):
)
embedding_dim = model_channels * 8
self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim)
# nn.Embedding
self.code_embedding = bnb.nn.StableEmbedding(num_tokens+1, embedding_dim)
self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1)
self.enable_unaligned_inputs = enabled_unaligned_inputs
if enabled_unaligned_inputs:
self.unaligned_embedder = nn.Embedding(num_unaligned_tokens, embedding_dim)
# nn.Embedding
self.unaligned_embedder = bnb.nn.StableEmbedding(num_unaligned_tokens, embedding_dim)
self.unaligned_encoder = CheckpointedXTransformerEncoder(
max_seq_len=-1,
use_pos_emb=False,

View File

@ -5,6 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers import Encoder
import bitsandbytes as bnb
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
@ -206,7 +207,8 @@ class DiffusionTts(nn.Module):
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
self.code_converter = nn.Sequential(
nn.Embedding(in_tokens, conditioning_dim),
# nn.Embedding
bnb.nn.StableEmbedding(in_tokens, conditioning_dim),
CheckpointedXTransformerEncoder(
needs_permute=False,
max_seq_len=-1,

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import bitsandbytes as bnb
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
@ -193,7 +194,9 @@ class DiffusionTtsFlat(nn.Module):
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
self.code_embedding = nn.Embedding(in_tokens, model_channels)
# nn.Embedding
self.code_embedding = bnb.nn.StableEmbedding(in_tokens, model_channels)
self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
@ -12,6 +13,7 @@ from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_e
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
class ResBlock(nn.Module):
"""
@ -279,9 +281,11 @@ class UnifiedVoice(nn.Module):
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.average_conditioning_embeddings = average_conditioning_embeddings
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
# nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens, model_dim)
if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
# nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
@ -294,8 +298,8 @@ class UnifiedVoice(nn.Module):
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding]

View File

@ -1,6 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
@ -271,15 +273,17 @@ class UnifiedVoice(nn.Module):
self.model_dim = model_dim
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
# nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim)
# nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.aligned_head = nn.Linear(model_dim, number_aligned_text_codes)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
self.aligned_head = bnb.nn.Linear8bitLt(model_dim, number_aligned_text_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding]

View File

@ -11,6 +11,7 @@ from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
class ResBlock(nn.Module):
@ -255,15 +256,17 @@ class UnifiedVoice(nn.Module):
self.model_dim = model_dim
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
# nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim)
# nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
self.alignment_head = nn.Linear(model_dim, 256)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
self.alignment_head = bnb.nn.Linear8bitLt(model_dim, 256)
if only_alignment_head:
for p in self.parameters():

View File

@ -8,6 +8,7 @@ from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.injectors.spec_augment import spec_augment
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
def exists(val):
@ -36,7 +37,7 @@ class VoiceCLIP(nn.Module):
self.encoder = AudioMiniEncoder(80, encoder_output)
if pretrained_encoder_dict_path is not None:
self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
self.to_latent = nn.Linear(encoder_output, dim_latent, bias=False)
self.to_latent = bnb.nn.Linear8bitLt(encoder_output, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.mel_compression_ratio = mel_compression_ratio

View File

@ -7,6 +7,7 @@ from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model
import bitsandbytes as bnb
class CheckpointedLayer(nn.Module):
@ -56,7 +57,8 @@ class Wav2VecMatcher(nn.Module):
WAV2VEC_CHANNELS = 1024
self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
self.text_embedding = nn.Embedding(num_text_tokens, model_dim)
# nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(num_text_tokens, model_dim)
self.encoder = CheckpointedXTransformer(
max_seq_len=-1,
use_pos_emb=False,
@ -73,8 +75,8 @@ class Wav2VecMatcher(nn.Module):
)
self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim))
self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim))
self.w2v_query_encoder = nn.Linear(WAV2VEC_CHANNELS, model_dim)
self.w2v_value_encoder = nn.Linear(WAV2VEC_CHANNELS, model_dim)
self.w2v_query_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim)
self.w2v_value_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim)
self.decoder = CheckpointedXTransformer(
max_seq_len=-1, # Should be unused
use_pos_emb=False,

View File

@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import bitsandbytes as bnb
from trainer.networks import register_model
@ -98,7 +99,7 @@ class ResNet(nn.Module):
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(256 * block.expansion, num_classes)
self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
"""make resnet layers(by layer i didnt mean this 'layer' was the

View File

@ -4,6 +4,7 @@ import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
import torchvision
import bitsandbytes as bnb
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@ -194,5 +195,5 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
def register_resnet50(opt_net, opt):
model = resnet50(pretrained=opt_net['pretrained'])
if opt_net['custom_head_logits']:
model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
model.fc = bnb.nn.Linear8bitLt(512 * 4, opt_net['custom_head_logits'])
return model

View File

@ -11,6 +11,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from trainer.networks import register_model
@ -101,7 +102,7 @@ class ResNet(nn.Module):
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(256 * block.expansion, num_classes)
self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
"""make resnet layers(by layer i didnt mean this 'layer' was the

View File

@ -11,6 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
from models.vqvae.scaled_weight_conv import ScaledWeightConv
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
@ -213,7 +214,7 @@ class ResNet(nn.Module):
self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, ScaledWeightConv):

View File

@ -3,7 +3,7 @@ import torch.nn as nn
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
class WideKernelVgg(nn.Module):
def __init__(self, nf=64, num_classes=2):
@ -49,9 +49,9 @@ class WideKernelVgg(nn.Module):
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Flatten(),
nn.Linear(nf * 8 * 4 * 2, 100),
bnb.nn.Linear8bitLt(nf * 8 * 4 * 2, 100),
nn.ReLU(),
nn.Linear(100, num_classes)
bnb.nn.Linear8bitLt(100, num_classes)
)
# These normalization constants should be derived experimentally.

View File

@ -10,6 +10,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
import bitsandbytes as bnb
def exists(val):
@ -58,7 +59,8 @@ class CollapsingTransformer(nn.Module):
class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.emb = nn.Embedding(*args, **kwargs)
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@ -98,9 +100,10 @@ class CLVP(nn.Module):
self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True)
self.mask_conditioning_percentage = mask_conditioning_percentage
self.text_emb = nn.Embedding(num_text_tokens, model_dim)
# nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, model_dim)
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
self.to_text_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.distributed_collect = distributed_collect
if mel_codes is None:
@ -108,7 +111,7 @@ class CLVP(nn.Module):
else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self):
return {

View File

@ -9,6 +9,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
import bitsandbytes as bnb
def exists(val):
@ -178,7 +179,8 @@ class CollapsingTransformer(nn.Module):
class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.emb = nn.Embedding(*args, **kwargs)
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@ -203,8 +205,8 @@ class ContrastiveAudio(nn.Module):
self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent)
self.to_latent = nn.Linear(latent_dim, latent_dim, bias=False)
self.to_latent2 = nn.Linear(latent_dim, latent_dim, bias=False)
self.to_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.to_latent2 = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.to_latent2.weight.data = self.to_latent.weight.data
self.to_latent2.weight.DO_NOT_TRAIN = True

View File

@ -10,6 +10,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
import bitsandbytes as bnb
def exists(val):
@ -58,7 +59,8 @@ class CollapsingTransformer(nn.Module):
class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.emb = nn.Embedding(*args, **kwargs)
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@ -86,14 +88,14 @@ class CVVP(nn.Module):
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
self.to_conditioning_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
if mel_codes is None:
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self):
return {

View File

@ -7,6 +7,7 @@ from torch import einsum
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
def exists(val):
@ -45,17 +46,20 @@ class MelTextCLIP(nn.Module):
mel_compression=256,
):
super().__init__()
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
# nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text)
# nn.Embedding
self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text)
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
heads=text_heads, rotary_emb=False)
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False)
self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1)
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
# nn.Embedding
self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.text_mask_percentage = text_mask_percentage

View File

@ -7,6 +7,7 @@ from models.audio.tts.unified_voice2 import ConditioningEncoder
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
def exists(val):
@ -45,7 +46,7 @@ class VoiceCondCLIP(nn.Module):
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.voice_mask_percentage = voice_mask_percentage

View File

@ -11,6 +11,7 @@ from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
def exists(val):
@ -53,11 +54,13 @@ class VoiceCLIP(nn.Module):
distributed_collect=False,
):
super().__init__()
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
# nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text)
self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False)
self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
# nn.Embedding
self.speech_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
if use_xformers:
self.text_transformer = CheckpointedXTransformerEncoder(
@ -105,8 +108,10 @@ class VoiceCLIP(nn.Module):
self.min_mel_size = min_mel_size
self.distributed_collect = distributed_collect
if not use_xformers:
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
# nn.Embedding
self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text)
# nn.Embedding
self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
def embed_text(self, text):
text_mask = torch.ones_like(text.float()).bool()

View File

@ -6,6 +6,7 @@ import math
import torch as th
import torch.nn as nn
import bitsandbytes as bnb
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
@ -36,7 +37,7 @@ def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
return bnb.nn.Linear8bitLt(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):

View File

@ -6,6 +6,7 @@ from models.arch_util import ConvGnLelu, default_init_weights, make_layer
from models.diffusion.nn import timestep_embedding
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
# Conditionally uses torch's checkpoint functionality if it is enabled in the opt file.
@ -28,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True)
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
bnb.nn.Linear8bitLt(
mid_channels*4,
mid_channels,
),
@ -143,9 +144,9 @@ class RRDBNet(nn.Module):
# Guided diffusion uses a time embedding.
time_embed_dim = mid_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(mid_channels, time_embed_dim),
bnb.nn.Linear8bitLt(mid_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
bnb.nn.Linear8bitLt(time_embed_dim, time_embed_dim),
)
self.body = make_layer(

View File

@ -20,6 +20,7 @@ from models.diffusion.nn import (
)
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
class AttentionPool2d(nn.Module):
@ -515,7 +516,8 @@ class UNetModel(nn.Module):
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
# nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
@ -867,16 +869,16 @@ class EncoderUNetModel(nn.Module):
)
elif pool == "spatial":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
bnb.nn.Linear8bitLt(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
bnb.nn.Linear8bitLt(2048, self.out_channels),
)
elif pool == "spatial_v2":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
bnb.nn.Linear8bitLt(self._feature_size, 2048),
normalization(2048),
nn.SiLU(),
nn.Linear(2048, self.out_channels),
bnb.nn.Linear8bitLt(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")

View File

@ -26,6 +26,7 @@ from models.diffusion.nn import (
)
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
class AttentionPool2d(nn.Module):
@ -476,7 +477,8 @@ class UNetModel(nn.Module):
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
# nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[
@ -736,7 +738,7 @@ class ResNetEncoder(nn.Module):
dilate=replace_stride_with_dilation[2])
f=512
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(f * block.expansion, output_dim)
self.fc = bnb.nn.Linear8bitLt(f * block.expansion, output_dim)
for m in self.modules():
if isinstance(m, nn.Conv2d):

View File

@ -5,6 +5,7 @@ import torch.nn.functional as F
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
import bitsandbytes as bnb
class Discriminator_VGG_128(nn.Module):
@ -46,8 +47,8 @@ class Discriminator_VGG_128(nn.Module):
input_img_factor = input_img_factor // 2
final_nf = nf * 16
self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
self.linear2 = nn.Linear(100, 1)
self.linear1 = bnb.nn.Linear8bitLt(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
@ -129,8 +130,8 @@ class Discriminator_VGG_128_GN(nn.Module):
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
self.linear2 = nn.Linear(100, 1)
self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
def compute_body(self, x):
fea = self.lrelu(self.conv0_0(x))
@ -219,8 +220,8 @@ class DiscriminatorVGG448GN(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
final_nf = nf * 8
self.linear1 = nn.Linear(int(final_nf * 7 * 7), 100)
self.linear2 = nn.Linear(100, 1)
self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 7 * 7), 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
# Assign all new heads to the new param group.2
for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]:

View File

@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import bitsandbytes as bnb
def initialize_weights(net_l, scale=1):
@ -14,7 +15,7 @@ def initialize_weights(net_l, scale=1):
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
elif isinstance(m, bnb.nn.Linear8bitLt):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:

View File

@ -28,6 +28,7 @@ except:
APEX_AVAILABLE = False
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
import bitsandbytes as bnb
num_cores = multiprocessing.cpu_count()
@ -351,7 +352,7 @@ class RGBBlock(nn.Module):
def __init__(self, latent_dim, input_channel, upsample, rgba=False):
super().__init__()
self.input_channel = input_channel
self.to_style = nn.Linear(latent_dim, input_channel)
self.to_style = bnb.nn.Linear8bitLt(latent_dim, input_channel)
out_filters = 3 if not rgba else 4
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
@ -489,16 +490,16 @@ class GeneratorBlockWithStructure(nn.Module):
# Uses stylegan1 style blocks for injecting structural latent.
self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
self.to_noise0 = nn.Linear(1, filters)
self.to_noise0 = bnb.nn.Linear8bitLt(1, filters)
self.noise0 = equal_lr(NoiseInjection(filters))
self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
self.to_style1 = nn.Linear(latent_dim, filters)
self.to_noise1 = nn.Linear(1, filters)
self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, filters)
self.to_noise1 = bnb.nn.Linear8bitLt(1, filters)
self.conv1 = Conv2DMod(filters, filters, 3)
self.to_style2 = nn.Linear(latent_dim, filters)
self.to_noise2 = nn.Linear(1, filters)
self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters)
self.to_noise2 = bnb.nn.Linear8bitLt(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
@ -540,12 +541,12 @@ class GeneratorBlock(nn.Module):
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
input_channels = input_channels * 2
self.to_style1 = nn.Linear(latent_dim, input_channels)
self.to_noise1 = nn.Linear(1, filters)
self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, input_channels)
self.to_noise1 = bnb.nn.Linear8bitLt(1, filters)
self.conv1 = Conv2DMod(input_channels, filters, 3)
self.to_style2 = nn.Linear(latent_dim, filters)
self.to_noise2 = nn.Linear(1, filters)
self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters)
self.to_noise2 = bnb.nn.Linear8bitLt(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
@ -724,7 +725,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt} and hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks:
@ -804,7 +805,7 @@ class StyleGan2Discriminator(nn.Module):
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.flatten = Flatten()
self.to_logit = nn.Linear(latent_dim, 1)
self.to_logit = bnb.nn.Linear8bitLt(latent_dim, 1)
self._init_weights()
@ -836,7 +837,7 @@ class StyleGan2Discriminator(nn.Module):
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, nn.Linear}:
if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

View File

@ -12,6 +12,7 @@ from torch import nn
from data.images.byol_attachment import RandomApply
from trainer.networks import register_model, create_model
from utils.util import checkpoint, opt_get
import bitsandbytes as bnb
def default(val, def_val):
@ -78,10 +79,10 @@ class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size=4096):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_size),
bnb.nn.Linear8bitLt(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
bnb.nn.Linear8bitLt(hidden_size, projection_size)
)
def forward(self, x):
@ -103,10 +104,10 @@ class StructuralMLP(nn.Module):
nn.BatchNorm2d(c),
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(flattened_dim, hidden_size),
bnb.nn.Linear8bitLt(flattened_dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
bnb.nn.Linear8bitLt(hidden_size, projection_size)
)
def forward(self, x):

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
import bitsandbytes as bnb
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
@ -108,8 +109,8 @@ class FixupResNet(nn.Module):
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
self.bias2 = nn.Parameter(torch.zeros(1))
reduced_img_sz = int(input_img_size / 32)
self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = nn.Linear(100, num_classes)
self.fc1 = bnb.nn.Linear8bitLt(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = bnb.nn.Linear8bitLt(100, num_classes)
for m in self.modules():
if isinstance(m, FixupBasicBlock):
@ -124,7 +125,7 @@ class FixupResNet(nn.Module):
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
'''
elif isinstance(m, nn.Linear):
elif isinstance(m, bnb.nn.Linear8bitLt):
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)'''

View File

@ -5,6 +5,7 @@ import torch.nn.functional as F
from models.arch_util import ResBlock
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
import bitsandbytes as bnb
class VitLatent(nn.Module):
@ -31,10 +32,10 @@ class VitLatent(nn.Module):
do_checkpointing=True
)
self.mlp = nn.Sequential(nn.Linear(hidden_dim, hidden_dim*2),
self.mlp = nn.Sequential(bnb.nn.Linear8bitLt(hidden_dim, hidden_dim*2),
nn.BatchNorm1d(hidden_dim*2),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim*2, hidden_dim))
bnb.nn.Linear8bitLt(hidden_dim*2, hidden_dim))
def provide_ema(self, ema):
self.ema = ema

View File

@ -7,6 +7,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from rotary_embedding_torch import apply_rotary_emb
import bitsandbytes as bnb
# helpers
@ -47,9 +48,9 @@ class Attention(nn.Module):
self.stable = stable
self.causal = causal
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
bnb.nn.Linear8bitLt(inner_dim, dim),
nn.Dropout(dropout)
)
@ -102,10 +103,10 @@ class SparseConvCausalAttention(nn.Module):
self.stable = stable
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
bnb.nn.Linear8bitLt(inner_dim, dim),
nn.Dropout(dropout)
)
@ -222,10 +223,10 @@ class SparseAxialCausalAttention(nn.Module):
self.stable = stable
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
bnb.nn.Linear8bitLt(inner_dim, dim),
nn.Dropout(dropout)
)

View File

@ -11,6 +11,7 @@ from models.lucidrains.dalle.attention import Attention, SparseAttention, Sparse
from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock
import bitsandbytes as bnb
# helpers
@ -78,10 +79,10 @@ class FeedForward(nn.Module):
def __init__(self, dim, dropout = 0., mult = 4.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
bnb.nn.Linear8bitLt(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
bnb.nn.Linear8bitLt(dim * mult, dim)
)
def forward(self, x):

View File

@ -21,6 +21,7 @@ try:
APEX_AVAILABLE = True
except:
APEX_AVAILABLE = False
import bitsandbytes as bnb
# helpers
@ -356,10 +357,10 @@ class FeedForward(nn.Module):
activation = default(activation, nn.GELU)
self.glu = glu
self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
self.w1 = bnb.nn.Linear8bitLt(dim, dim * mult * (2 if glu else 1))
self.act = activation()
self.dropout = nn.Dropout(dropout)
self.w2 = nn.Linear(dim * mult, dim)
self.w2 = bnb.nn.Linear8bitLt(dim * mult, dim)
def forward(self, x, **kwargs):
if not self.glu:
@ -401,10 +402,10 @@ class Attention(nn.Module):
self.global_heads = heads - local_heads
self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias)
self.to_q = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
self.to_k = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
self.to_v = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
self.to_out = bnb.nn.Linear8bitLt(inner_dim, dim, bias = attn_out_bias)
self.dropout = nn.Dropout(dropout)
def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs):
@ -458,7 +459,8 @@ class CrossAttention(Attention):
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(max_seq_len, dim)
def forward(self, x):
t = torch.arange(x.shape[1], device=x.device)
@ -619,7 +621,8 @@ class PerformerLM(nn.Module):
local_attn_heads = cast_tuple(local_attn_heads)
self.max_seq_len = max_seq_len
self.token_emb = nn.Embedding(num_tokens, dim)
# nn.Embedding
self.token_emb = bnb.nn.StableEmbedding(num_tokens, dim)
if rotary_position_emb:
self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
@ -636,7 +639,7 @@ class PerformerLM(nn.Module):
self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens)
self.norm = nn.LayerNorm(dim)
self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None
self.to_out = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embed else None
def check_redraw_projections(self):
self.performer.check_redraw_projections()

View File

@ -8,6 +8,7 @@ from torch.cuda.amp import autocast
from einops import rearrange, repeat
from contextlib import contextmanager
import bitsandbytes as bnb
def par(t, nm):
@ -355,9 +356,9 @@ class VectorQuantize(nn.Module):
codebook_dim = default(codebook_dim, dim)
requires_projection = codebook_dim != dim
self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \
self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection \
else nn.Identity()
self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \
self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection \
else nn.Identity()
self.eps = eps

View File

@ -11,6 +11,7 @@ from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from torch.utils.checkpoint import checkpoint
import bitsandbytes as bnb
DEFAULT_DIM_HEAD = 64
@ -125,7 +126,8 @@ class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.emb = nn.Embedding(max_seq_len, dim)
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(max_seq_len, dim)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
@ -154,7 +156,8 @@ class RelativePositionBias(nn.Module):
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
# nn.Embedding
self.relative_attention_bias = bnb.nn.StableEmbedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
@ -360,7 +363,7 @@ class RMSScaleShiftNorm(nn.Module):
self.cdim = 1
self.pdim = -1
else:
self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)
self.scale_shift_process = bnb.nn.Linear8bitLt(embed_dim, dim * 2, bias=bias)
self.cdim = -1
self.pdim = 1
@ -447,7 +450,7 @@ class GLU(nn.Module):
def __init__(self, dim_in, dim_out, activation):
super().__init__()
self.act = activation
self.proj = nn.Linear(dim_in, dim_out * 2)
self.proj = bnb.nn.Linear8bitLt(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@ -472,7 +475,7 @@ class FeedForward(nn.Module):
activation = ReluSquared() if relu_squared else nn.GELU()
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
bnb.nn.Linear8bitLt(dim, inner_dim),
activation
) if not glu else GLU(dim, inner_dim, activation)
@ -480,7 +483,7 @@ class FeedForward(nn.Module):
project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
bnb.nn.Linear8bitLt(inner_dim, dim_out)
)
# init last linear layer to 0
@ -535,16 +538,16 @@ class Attention(nn.Module):
qk_dim = int(collab_compression * qk_dim)
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
self.to_q = nn.Linear(dim, qk_dim, bias=False)
self.to_k = nn.Linear(dim, qk_dim, bias=False)
self.to_v = nn.Linear(dim, v_dim, bias=False)
self.to_q = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False)
self.to_k = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False)
self.to_v = bnb.nn.Linear8bitLt(dim, v_dim, bias=False)
self.dropout = nn.Dropout(dropout)
# add GLU gating for aggregated values, from alphafold2
self.to_v_gate = None
if gate_values:
self.to_v_gate = nn.Linear(dim, v_dim)
self.to_v_gate = bnb.nn.Linear8bitLt(dim, v_dim)
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 1)
@ -581,7 +584,7 @@ class Attention(nn.Module):
# attention on attention
self.attn_on_attn = on_attn
out_dim = default(out_dim, dim)
self.to_out = nn.Sequential(nn.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, out_dim)
self.to_out = nn.Sequential(bnb.nn.Linear8bitLt(v_dim, out_dim * 2), nn.GLU()) if on_attn else bnb.nn.Linear8bitLt(v_dim, out_dim)
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
@ -1077,7 +1080,7 @@ class ViTransformerWrapper(nn.Module):
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.patch_to_embedding = bnb.nn.Linear8bitLt(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
@ -1135,18 +1138,19 @@ class TransformerWrapper(nn.Module):
self.max_mem_len = max_mem_len
self.shift_mem_down = shift_mem_down
self.token_emb = nn.Embedding(num_tokens, emb_dim)
# nn.Embedding
self.token_emb = bnb.nn.StableEmbedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.project_emb = bnb.nn.Linear8bitLt(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
self.to_logits = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
@ -1233,12 +1237,12 @@ class ContinuousTransformerWrapper(nn.Module):
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
self.project_in = bnb.nn.Linear8bitLt(dim_in, dim) if exists(dim_in) else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
self.project_out = bnb.nn.Linear8bitLt(dim, dim_out) if exists(dim_out) else nn.Identity()
def forward(
self,

View File

@ -4,13 +4,15 @@ import torch.nn.functional as F
from torch import einsum
from utils.weight_scheduler import LinearDecayWeightScheduler
import bitsandbytes as bnb
class GumbelQuantizer(nn.Module):
def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False):
super().__init__()
self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
self.codebook = nn.Embedding(num_tokens, codebook_dim)
# nn.Embedding
self.codebook = bnb.nn.StableEmbedding(num_tokens, codebook_dim)
self.straight_through = straight_through
self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
self.step = 0

View File

@ -4,6 +4,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from models.arch_util import l2norm, sample_vectors, default, ema_inplace
import bitsandbytes as bnb
def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
@ -184,8 +185,8 @@ class VectorQuantize(nn.Module):
codebook_dim = default(codebook_dim, dim)
requires_projection = codebook_dim != dim
self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection else nn.Identity()
self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection else nn.Identity()
self.eps = eps

View File

@ -21,6 +21,7 @@ import torchvision.utils as utils
from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
from utils.util import opt_get, denormalize
import bitsandbytes as bnb
logger = logging.getLogger('base')
@ -337,7 +338,7 @@ class ExtensibleTrainer(BaseModel):
for net in self.networks.values():
for mod in net.modules():
fan_in = -1
if isinstance(mod, nn.Linear):
if isinstance(mod, bnb.nn.Linear8bitLt):
fan_in = mod.weight.data.shape[1]
elif isinstance(mod, nn.Conv1d):
fan_in = mod.weight.data.shape[0]

View File

@ -1,3 +1,4 @@
import logging
from collections import OrderedDict
@ -6,6 +7,7 @@ import torch.nn as nn
import trainer.networks as networks
import trainer.lr_scheduler as lr_scheduler
from .base_model import BaseModel
import bitsandbytes as bnb
logger = logging.getLogger('base')
@ -40,7 +42,8 @@ class FeatureModel(BaseModel):
else:
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
# torch.optim.Adam
self.optimizer_G = bnb.optim.Adam8bit(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
self.optimizers.append(self.optimizer_G)

View File

@ -3,10 +3,10 @@ from collections import Counter
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import _LRScheduler
import bitsandbytes as bnb
from utils.util import opt_get
def get_scheduler_for_name(name, optimizers, scheduler_opt):
schedulers = []
for o in optimizers:
@ -136,7 +136,8 @@ class CosineAnnealingLR_Restart(_LRScheduler):
if __name__ == "__main__":
optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
#torch.optim.Adam
optimizer = bnb.optim.Adam8bit([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
betas=(0.9, 0.99))
##############################
# MultiStepLR_Restart

View File

@ -12,6 +12,7 @@ from utils.util import recursively_detach, opt_get, clip_grad_norm
logger = logging.getLogger('base')
import bitsandbytes as bnb
# Defines the expected API for a single training step
class ConfigurableStep(Module):
@ -82,7 +83,8 @@ class ConfigurableStep(Module):
import torch.nn as nn
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
emb_modules = (nn.Embedding, nn.EmbeddingBag)
# nn.Embedding
emb_modules = (bnb.nn.StableEmbedding, nn.EmbeddingBag)
param_names_notweights = set()
all_param_names = set()
param_map = {}
@ -123,7 +125,8 @@ class ConfigurableStep(Module):
{ 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
{ 'params': params_notweights, 'weight_decay': 0 }
]
opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
# torch.optim.AdamW
opt = bnb.optim.AdamW8bit(groups, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt._group_names = [params_names_weights, params_names_notweights]
@ -141,14 +144,16 @@ class ConfigurableStep(Module):
# The torch ZeRO implementation does not seem to support parameter groups, so do not shard the non-weighted
# parameters and just use a normal AdamW implementation. In a large network, these weights will normally
# be a tiny fraction of the total weights.
opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
# torch.optim.AdamW
opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0,
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt_unweighted._config = opt_config
opt_unweighted._config['network'] = net_name
opt_unweighted._group_names = []
self.optimizers.append(opt_unweighted)
opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
# torch.optim.AdamW
opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=bnb.optim.AdamW8bit, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt.param_groups[0]['initial_lr'] = opt_config['lr']
@ -162,7 +167,8 @@ class ConfigurableStep(Module):
opt._group_names = sorted(list(all_param_names))
elif self.step_opt['optimizer'] == 'lamb':
from trainer.optimizers.lamb import Lamb
opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
# torch.optim.AdamW
opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0,
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt_unweighted._config = opt_config
opt_unweighted._config['network'] = net_name