I sucked off the hyptothetical wizard again, just using BNB's ADAM optimizer nets HUGE savings, but I don't know the output costs, will need to test

This commit is contained in:
mrq 2023-02-23 02:42:17 +00:00
parent 01c0941a40
commit 6676c89c0e
68 changed files with 317 additions and 276 deletions

View File

@ -9,7 +9,7 @@ import torch.nn.utils.spectral_norm as SpectralNorm
from math import sqrt
from utils.util import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
def exists(val):
@ -74,7 +74,7 @@ def initialize_weights(net_l, scale=1):
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, bnb.nn.Linear8bitLt):
elif isinstance(m, ml.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
@ -109,7 +109,7 @@ def default_init_weights(module, scale=1):
if isinstance(m, nn.Conv2d):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
elif isinstance(m, bnb.nn.Linear8bitLt):
elif isinstance(m, ml.Linear):
kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale
@ -142,7 +142,7 @@ def linear(*args, **kwargs):
"""
Create a linear module.
"""
return bnb.nn.Linear8bitLt(*args, **kwargs)
return ml.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):

View File

@ -9,7 +9,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.tts.tacotron2.text import sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
def only_letters(string):
@ -52,7 +52,7 @@ class Wav2VecWrapper(nn.Module):
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
# Perform some surgery to get the model we actually want.
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
self.w2v.lm_head = bnb.nn.Linear8bitLt(self.w2v.config.hidden_size, vocab_size)
self.w2v.lm_head = ml.Linear(self.w2v.config.hidden_size, vocab_size)
self.w2v.config.vocab_size = vocab_size
self.w2v.config.pad_token_id = 0
self.w2v.config.ctc_loss_reduction = 'sum'

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from trainer.networks import register_model
from utils.util import opt_get
from typing import Type, Any, Callable, Union, List, Optional
import bitsandbytes as bnb
import torch_intermediary as ml
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@ -173,7 +173,7 @@ class ResNet(nn.Module):
self.layer4 = self._make_layer(block, 512, layers[3], stride=4,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes)
self.fc = ml.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv1d):

View File

@ -15,14 +15,14 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from models.arch_util import ResBlock
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
class Mel2Vec2FeatureProjection(nn.Module):
def __init__(self, inner_dim, dropout):
super().__init__()
self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5)
self.projection = bnb.nn.Linear8bitLt(inner_dim, inner_dim)
self.projection = ml.Linear(inner_dim, inner_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, hidden_states):
@ -59,10 +59,10 @@ class Wav2Vec2Attention(nn.Module):
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
self.v_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
self.q_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
self.out_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
self.k_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@ -183,10 +183,10 @@ class Wav2Vec2FeedForward(nn.Module):
super().__init__()
self.intermediate_dropout = nn.Dropout(dropout)
self.intermediate_dense = bnb.nn.Linear8bitLt(hidden_size, intermediate_size)
self.intermediate_dense = ml.Linear(hidden_size, intermediate_size)
self.intermediate_act_fn = F.gelu
self.output_dense = bnb.nn.Linear8bitLt(intermediate_size, hidden_size)
self.output_dense = ml.Linear(intermediate_size, hidden_size)
self.output_dropout = nn.Dropout(dropout)
def forward(self, hidden_states):
@ -430,7 +430,7 @@ class Mel2Vec(nn.Module):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, bnb.nn.Linear8bitLt):
elif isinstance(module, ml.Linear):
if self.disable_custom_linear_init:
return
module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
@ -511,7 +511,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
)
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2
@ -607,8 +607,8 @@ class ContrastiveTrainingWrapper(nn.Module):
self.inp_length_factor = inp_length_multiplier
# make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = bnb.nn.Linear8bitLt(inner_dim, self.quantizer.codevector_dim)
self.project_q = bnb.nn.Linear8bitLt(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
self.project_hid = ml.Linear(inner_dim, self.quantizer.codevector_dim)
self.project_q = ml.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
self.reconstruction = do_reconstruction_loss
if do_reconstruction_loss:

View File

@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb
import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
@ -57,8 +57,8 @@ class ConditioningAR(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(num_vectors, dim)
self.head = bnb.nn.Linear8bitLt(dim, num_vectors)
self.embeddings = ml.Embedding(num_vectors, dim)
self.head = ml.Linear(dim, num_vectors)
def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
unused_params = []

View File

@ -17,7 +17,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from math import sqrt
@ -25,7 +25,7 @@ from torch.utils.checkpoint import checkpoint
from trainer.networks import register_model
Linear = bnb.nn.Linear8bitLt
Linear = ml.Linear
ConvTranspose2d = nn.ConvTranspose2d

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import bitsandbytes as bnb
import torch_intermediary as ml
from models.arch_util import ResBlock
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@ -24,7 +24,7 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
# nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -161,7 +161,7 @@ class FlatDiffusion(nn.Module):
# transformer network.
if in_groups is None:
# nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels)
self.embeddings = ml.Embedding(token_count, model_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential(

View File

@ -2,7 +2,7 @@ import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb
import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock
from models.audio.music.music_quantizer import MusicQuantizer
@ -138,8 +138,8 @@ class GptMusicLower(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)])
self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
def forward(self, mel, conditioning, return_latent=False):
unused_params = []
@ -241,8 +241,8 @@ class GptMusicUpper(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_upper_vectors) for _ in range(num_upper_groups)])
self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
def forward(self, mel, conditioning, return_latent=False):

View File

@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb
import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
@ -75,8 +75,8 @@ class GptMusicLower(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)])
self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
def forward(self, mel, return_latent=False):
unused_params = []

View File

@ -3,7 +3,7 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding
from models.lucidrains.vq import VectorQuantize
@ -22,8 +22,8 @@ class SelfClassifyingHead(nn.Module):
use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2,
sample_codebook_temp=init_temperature)
self.to_output = bnb.nn.Linear8bitLt(dim, out_dim)
self.to_decoder = bnb.nn.Linear8bitLt(out_dim, dim)
self.to_output = ml.Linear(dim, out_dim)
self.to_decoder = ml.Linear(out_dim, dim)
def do_ar_step(self, x, used_codes):
h = self.dec(x)
@ -91,7 +91,7 @@ class InstrumentQuantizer(nn.Module):
"""
super().__init__()
self.op_dim = op_dim
self.proj = bnb.nn.Linear8bitLt(op_dim, dim)
self.proj = ml.Linear(op_dim, dim)
self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
self.min_gumbel_temperature = min_temp

View File

@ -1,7 +1,7 @@
import torch
from torch import nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from transformers import GPT2Config, GPT2Model
from trainer.networks import register_model
@ -19,8 +19,8 @@ class Mel2VecCodesGpt(nn.Module):
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_vectors) for _ in range(num_groups)])
self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)])
def forward(self, codes):
assert codes.shape[-1] == self.num_groups

View File

@ -3,7 +3,7 @@ import functools
import torch
from torch import nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize
@ -76,7 +76,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
)
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2

View File

@ -3,7 +3,7 @@ import functools
import torch
from torch import nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize
@ -88,7 +88,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
)
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchvision
import bitsandbytes as bnb
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
@ -56,12 +56,12 @@ class ConcatAttentionBlock(TimestepBlock):
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
if cond_projection:
self.tdim = trunk_dim+cond_dim_hidden
self.cond_project = bnb.nn.Linear8bitLt(cond_dim_in, cond_dim_hidden)
self.cond_project = ml.Linear(cond_dim_in, cond_dim_hidden)
else:
self.tdim = trunk_dim
self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv)
self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv)
self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False)
self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
self.out.weight.data.zero_()
def forward(self, x, cond, timestep_emb, rotary_emb):
@ -89,7 +89,7 @@ class ConditioningEncoder(nn.Module):
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
self.time_proj = time_proj
if time_proj:
self.time_proj = bnb.nn.Linear8bitLt(time_embed_dim, embedding_dim)
self.time_proj = ml.Linear(time_embed_dim, embedding_dim)
self.attn = Encoder(
dim=embedding_dim,
depth=attn_blocks,

View File

@ -4,7 +4,7 @@ from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from models.arch_util import ResBlock
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
@ -29,7 +29,7 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
# nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -70,7 +70,7 @@ class ConcatAttentionBlock(TimestepBlock):
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False)
self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
self.out.weight.data.zero_()
def forward(self, x, timestep_emb, rotary_emb):
@ -131,7 +131,7 @@ class TransformerDiffusion(nn.Module):
)
prenet_heads = prenet_channels//64
self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, prenet_channels)
self.input_converter = ml.Linear(input_vec_dim, prenet_channels)
self.code_converter = Encoder(
dim=prenet_channels,
depth=prenet_layers,
@ -147,7 +147,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels)
self.intg = ml.Linear(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(

View File

@ -5,7 +5,7 @@ from random import randrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
RelativeQKBias
@ -71,13 +71,13 @@ class ConditioningEncoder(nn.Module):
attn = []
self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
# nn.Embedding
self.resolution_embedding = bnb.nn.StableEmbedding(num_resolutions, hidden_dim)
self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim)
self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
for a in range(attn_blocks):
attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.out = bnb.nn.Linear8bitLt(hidden_dim, out_dim, bias=False)
self.out = ml.Linear(hidden_dim, out_dim, bias=False)
self.dim = hidden_dim
self.do_checkpointing = do_checkpointing
@ -134,7 +134,7 @@ class TransformerDiffusion(nn.Module):
linear(time_embed_dim, time_proj_dim),
)
# nn.Embedding
self.resolution_embed = bnb.nn.StableEmbedding(resolution_steps, time_proj_dim)
self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim)
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))

View File

@ -8,7 +8,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torchvision # For debugging, not actually used.
import bitsandbytes as bnb
import torch_intermediary as ml
from models.audio.music.gpt_music import GptMusicLower
from models.audio.music.music_quantizer import MusicQuantizer
@ -491,7 +491,7 @@ class UNetMusicModel(nn.Module):
)
if self.ar_prior:
self.ar_input = bnb.nn.Linear8bitLt(input_vec_dim, model_channels)
self.ar_input = ml.Linear(input_vec_dim, model_channels)
self.ar_prior_intg = Encoder(
dim=model_channels,
depth=4,
@ -505,7 +505,7 @@ class UNetMusicModel(nn.Module):
ff_mult=1,
)
else:
self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, model_channels)
self.input_converter = ml.Linear(input_vec_dim, model_channels)
self.code_converter = Encoder(
dim=model_channels,
depth=4,
@ -523,7 +523,7 @@ class UNetMusicModel(nn.Module):
if self.num_classes is not None:
# nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.

View File

@ -3,7 +3,7 @@ from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
from models.lucidrains.x_transformers import Encoder
@ -38,11 +38,11 @@ class CtcCodeGenerator(nn.Module):
pred_codes = (max_pad+1)*(max_repeat+1)
# nn.Embedding
self.position_embedding = bnb.nn.StableEmbedding(max_length, model_dim)
self.position_embedding = ml.Embedding(max_length, model_dim)
# nn.Embedding
self.codes_embedding = bnb.nn.StableEmbedding(ctc_codes, model_dim)
self.codes_embedding = ml.Embedding(ctc_codes, model_dim)
# nn.Embedding
self.recursive_embedding = bnb.nn.StableEmbedding(pred_codes, model_dim)
self.recursive_embedding = ml.Embedding(pred_codes, model_dim)
self.mask_embedding = nn.Parameter(torch.randn(model_dim))
self.encoder = Encoder(
dim=model_dim,
@ -54,8 +54,8 @@ class CtcCodeGenerator(nn.Module):
ff_glu=True,
rotary_pos_emb=True,
)
self.pred_head = bnb.nn.Linear8bitLt(model_dim, pred_codes)
self.confidence_head = bnb.nn.Linear8bitLt(model_dim, 1)
self.pred_head = ml.Linear(model_dim, pred_codes)
self.confidence_head = ml.Linear(model_dim, 1)
def inference(self, codes, pads, repeats):
position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))

View File

@ -5,7 +5,7 @@ from functools import partial
import torch
import torch.nn as nn
import bitsandbytes as bnb
import torch_intermediary as ml
from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
@ -18,7 +18,7 @@ class TimeIntegrationBlock(nn.Module):
super().__init__()
self.emb_layers = nn.Sequential(
nn.SiLU(),
bnb.nn.Linear8bitLt(
ml.Linear(
time_emb_dim,
2 * dim
),

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import bitsandbytes as bnb
import torch_intermediary as ml
from models.diffusion.nn import normalization, conv_nd, zero_module
@ -139,7 +139,7 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
def __init__(self, classes, distribute_zero_label=True, **kwargs):
super().__init__()
self.enc = AudioMiniEncoder(**kwargs)
self.head = bnb.nn.Linear8bitLt(self.enc.dim, classes)
self.head = ml.Linear(self.enc.dim, classes)
self.num_classes = classes
self.distribute_zero_label = distribute_zero_label
@ -184,7 +184,7 @@ class QueryProvidedAttentionBlock(nn.Module):
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = normalization(channels)
self.q = bnb.nn.Linear8bitLt(channels, channels)
self.q = ml.Linear(channels, channels)
self.qnorm = nn.LayerNorm(channels)
self.kv = conv_nd(1, channels, channels*2, 1)
if use_new_attention_order:

View File

@ -3,7 +3,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from trainer.networks import register_model
from utils.util import opt_get
@ -45,7 +45,7 @@ class RandomLatentConverter(nn.Module):
def __init__(self, channels):
super().__init__()
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
bnb.nn.Linear8bitLt(channels, channels))
ml.Linear(channels, channels))
self.channels = channels
def forward(self, ref):

View File

@ -3,13 +3,13 @@ from librosa.filters import mel as librosa_mel_fn
from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
from models.audio.tts.tacotron2.stft import STFT
import bitsandbytes as bnb
import torch_intermediary as ml
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = torch.bnb.nn.Linear8bitLt(in_dim, out_dim, bias=bias)
self.linear_layer = torch.ml.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,

View File

@ -8,7 +8,7 @@ from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
from models.audio.tts.tacotron2.hparams import create_hparams
from trainer.networks import register_model
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
import bitsandbytes as bnb
import torch_intermediary as ml
class LocationLayer(nn.Module):
@ -465,7 +465,7 @@ class Tacotron2(nn.Module):
self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step
# nn.Embedding
self.embedding = bnb.nn.StableEmbedding(
self.embedding = ml.Embedding(
hparams.n_symbols, hparams.symbols_embedding_dim)
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std

View File

@ -13,7 +13,7 @@ from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
from trainer.networks import register_model
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
@ -187,7 +187,7 @@ class WaveTacotron2(nn.Module):
self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step
# nn.Embedding
self.embedding = bnb.nn.StableEmbedding(
self.embedding = ml.Embedding(
hparams.n_symbols, hparams.symbols_embedding_dim)
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std

View File

@ -25,7 +25,7 @@ import random
from time import time
import torch
import torch.nn as nn
import bitsandbytes as bnb
import torch_intermediary as ml
from tqdm import tqdm
@ -37,7 +37,7 @@ class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02, relative=False):
super().__init__()
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(seq_len, model_dim)
self.emb = ml.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
self.relative = relative

View File

@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlo
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
def is_latent(t):
@ -21,7 +21,7 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
# nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -102,9 +102,9 @@ class TransformerDiffusionTTS(nn.Module):
ff_glu=True,
rotary_pos_emb=True,
)
self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, model_channels)
self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels)
# nn.Embedding
self.type_embedding = bnb.nn.StableEmbedding(types, model_channels)
self.type_embedding = ml.Embedding(types, model_channels)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@ -112,7 +112,7 @@ class TransformerDiffusionTTS(nn.Module):
# transformer network.
if in_groups is None:
# nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels)
self.embeddings = ml.Embedding(token_count, model_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential(
@ -144,7 +144,7 @@ class TransformerDiffusionTTS(nn.Module):
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = bnb.nn.Linear8bitLt(model_channels*2, model_channels)
self.intg = ml.Linear(model_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
@ -21,7 +21,7 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
# nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -42,7 +42,7 @@ class DietAttentionBlock(TimestepBlock):
def __init__(self, in_dim, dim, heads, dropout):
super().__init__()
self.rms_scale_norm = RMSScaleShiftNorm(in_dim)
self.proj = bnb.nn.Linear8bitLt(in_dim, dim)
self.proj = ml.Linear(in_dim, dim)
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout)
self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
@ -107,9 +107,9 @@ class TransformerDiffusionTTS(nn.Module):
ff_glu=True,
rotary_pos_emb=True,
)
self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, prenet_channels)
self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels)
# nn.Embedding
self.type_embedding = bnb.nn.StableEmbedding(types, prenet_channels)
self.type_embedding = ml.Embedding(types, prenet_channels)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@ -117,7 +117,7 @@ class TransformerDiffusionTTS(nn.Module):
# transformer network.
if in_groups is None:
# nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, prenet_channels)
self.embeddings = ml.Embedding(token_count, prenet_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels)
self.latent_conditioner = nn.Sequential(
@ -148,8 +148,8 @@ class TransformerDiffusionTTS(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.cond_intg = bnb.nn.Linear8bitLt(prenet_channels*4, model_channels)
self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels)
self.cond_intg = ml.Linear(prenet_channels*4, model_channels)
self.intg = ml.Linear(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import bitsandbytes as bnb
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
@ -249,7 +249,7 @@ class DiffusionTts(nn.Module):
embedding_dim = model_channels * 8
# nn.Embedding
self.code_embedding = bnb.nn.StableEmbedding(num_tokens+1, embedding_dim)
self.code_embedding = ml.Embedding(num_tokens+1, embedding_dim)
self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1)
@ -257,7 +257,7 @@ class DiffusionTts(nn.Module):
self.enable_unaligned_inputs = enabled_unaligned_inputs
if enabled_unaligned_inputs:
# nn.Embedding
self.unaligned_embedder = bnb.nn.StableEmbedding(num_unaligned_tokens, embedding_dim)
self.unaligned_embedder = ml.Embedding(num_unaligned_tokens, embedding_dim)
self.unaligned_encoder = CheckpointedXTransformerEncoder(
max_seq_len=-1,
use_pos_emb=False,

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers import Encoder
import bitsandbytes as bnb
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
@ -208,7 +208,7 @@ class DiffusionTts(nn.Module):
# transformer network.
self.code_converter = nn.Sequential(
# nn.Embedding
bnb.nn.StableEmbedding(in_tokens, conditioning_dim),
ml.Embedding(in_tokens, conditioning_dim),
CheckpointedXTransformerEncoder(
needs_permute=False,
max_seq_len=-1,

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import bitsandbytes as bnb
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
@ -196,7 +196,7 @@ class DiffusionTtsFlat(nn.Module):
# transformer network.
# nn.Embedding
self.code_embedding = bnb.nn.StableEmbedding(in_tokens, model_channels)
self.code_embedding = ml.Embedding(in_tokens, model_channels)
self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),

View File

@ -13,7 +13,7 @@ from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_e
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
class ResBlock(nn.Module):
"""
@ -282,10 +282,10 @@ class UnifiedVoice(nn.Module):
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.average_conditioning_embeddings = average_conditioning_embeddings
# nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens, model_dim)
self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
if use_mel_codes_as_input:
# nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
@ -298,8 +298,8 @@ class UnifiedVoice(nn.Module):
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
self.text_head = ml.Linear(model_dim, self.number_text_tokens)
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding]

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
@ -274,16 +274,16 @@ class UnifiedVoice(nn.Module):
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
# nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim)
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
# nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
self.aligned_head = bnb.nn.Linear8bitLt(model_dim, number_aligned_text_codes)
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
self.aligned_head = ml.Linear(model_dim, number_aligned_text_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding]

View File

@ -11,7 +11,7 @@ from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
class ResBlock(nn.Module):
@ -257,16 +257,16 @@ class UnifiedVoice(nn.Module):
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
# nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim)
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
# nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
self.alignment_head = bnb.nn.Linear8bitLt(model_dim, 256)
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
self.alignment_head = ml.Linear(model_dim, 256)
if only_alignment_head:
for p in self.parameters():

View File

@ -8,7 +8,7 @@ from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.injectors.spec_augment import spec_augment
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
def exists(val):
@ -37,7 +37,7 @@ class VoiceCLIP(nn.Module):
self.encoder = AudioMiniEncoder(80, encoder_output)
if pretrained_encoder_dict_path is not None:
self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
self.to_latent = bnb.nn.Linear8bitLt(encoder_output, dim_latent, bias=False)
self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.mel_compression_ratio = mel_compression_ratio

View File

@ -7,7 +7,7 @@ from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model
import bitsandbytes as bnb
import torch_intermediary as ml
class CheckpointedLayer(nn.Module):
@ -58,7 +58,7 @@ class Wav2VecMatcher(nn.Module):
self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
# nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(num_text_tokens, model_dim)
self.text_embedding = ml.Embedding(num_text_tokens, model_dim)
self.encoder = CheckpointedXTransformer(
max_seq_len=-1,
use_pos_emb=False,
@ -75,8 +75,8 @@ class Wav2VecMatcher(nn.Module):
)
self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim))
self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim))
self.w2v_query_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim)
self.w2v_value_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim)
self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
self.decoder = CheckpointedXTransformer(
max_seq_len=-1, # Should be unused
use_pos_emb=False,

View File

@ -10,7 +10,7 @@
import torch
import torch.nn as nn
import bitsandbytes as bnb
import torch_intermediary as ml
from trainer.networks import register_model
@ -99,7 +99,7 @@ class ResNet(nn.Module):
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes)
self.fc = ml.Linear(256 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
"""make resnet layers(by layer i didnt mean this 'layer' was the

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
import torchvision
import bitsandbytes as bnb
import torch_intermediary as ml
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@ -195,5 +195,5 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
def register_resnet50(opt_net, opt):
model = resnet50(pretrained=opt_net['pretrained'])
if opt_net['custom_head_logits']:
model.fc = bnb.nn.Linear8bitLt(512 * 4, opt_net['custom_head_logits'])
model.fc = ml.Linear(512 * 4, opt_net['custom_head_logits'])
return model

View File

@ -11,7 +11,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
from trainer.networks import register_model
@ -102,7 +102,7 @@ class ResNet(nn.Module):
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes)
self.fc = ml.Linear(256 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
"""make resnet layers(by layer i didnt mean this 'layer' was the

View File

@ -11,7 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
from models.vqvae.scaled_weight_conv import ScaledWeightConv
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
@ -214,7 +214,7 @@ class ResNet(nn.Module):
self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes)
self.fc = ml.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, ScaledWeightConv):

View File

@ -3,7 +3,7 @@ import torch.nn as nn
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
class WideKernelVgg(nn.Module):
def __init__(self, nf=64, num_classes=2):
@ -49,9 +49,9 @@ class WideKernelVgg(nn.Module):
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Flatten(),
bnb.nn.Linear8bitLt(nf * 8 * 4 * 2, 100),
ml.Linear(nf * 8 * 4 * 2, 100),
nn.ReLU(),
bnb.nn.Linear8bitLt(100, num_classes)
ml.Linear(100, num_classes)
)
# These normalization constants should be derived experimentally.

View File

@ -10,7 +10,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
def exists(val):
@ -60,7 +60,7 @@ class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@ -101,9 +101,9 @@ class CLVP(nn.Module):
self.mask_conditioning_percentage = mask_conditioning_percentage
# nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, model_dim)
self.text_emb = ml.Embedding(num_text_tokens, model_dim)
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
self.to_text_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.to_text_latent = ml.Linear(latent_dim, latent_dim, bias=False)
self.distributed_collect = distributed_collect
if mel_codes is None:
@ -111,7 +111,7 @@ class CLVP(nn.Module):
else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self):
return {

View File

@ -9,7 +9,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
def exists(val):
@ -180,7 +180,7 @@ class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@ -205,8 +205,8 @@ class ContrastiveAudio(nn.Module):
self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent)
self.to_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.to_latent2 = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.to_latent = ml.Linear(latent_dim, latent_dim, bias=False)
self.to_latent2 = ml.Linear(latent_dim, latent_dim, bias=False)
self.to_latent2.weight.data = self.to_latent.weight.data
self.to_latent2.weight.DO_NOT_TRAIN = True

View File

@ -10,7 +10,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
def exists(val):
@ -60,7 +60,7 @@ class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@ -88,14 +88,14 @@ class CVVP(nn.Module):
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
self.to_conditioning_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.to_conditioning_latent = ml.Linear(latent_dim, latent_dim, bias=False)
if mel_codes is None:
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self):
return {

View File

@ -7,7 +7,7 @@ from torch import einsum
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
def exists(val):
@ -47,19 +47,19 @@ class MelTextCLIP(nn.Module):
):
super().__init__()
# nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text)
self.text_emb = ml.Embedding(num_text_tokens, dim_text)
# nn.Embedding
self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text)
self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
heads=text_heads, rotary_emb=False)
self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False)
self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1)
# nn.Embedding
self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.text_mask_percentage = text_mask_percentage

View File

@ -7,7 +7,7 @@ from models.audio.tts.unified_voice2 import ConditioningEncoder
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
def exists(val):
@ -46,7 +46,7 @@ class VoiceCondCLIP(nn.Module):
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.voice_mask_percentage = voice_mask_percentage

View File

@ -11,7 +11,7 @@ from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
def exists(val):
@ -55,12 +55,12 @@ class VoiceCLIP(nn.Module):
):
super().__init__()
# nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text)
self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False)
self.text_emb = ml.Embedding(num_text_tokens, dim_text)
self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
# nn.Embedding
self.speech_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech)
self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
if use_xformers:
self.text_transformer = CheckpointedXTransformerEncoder(
@ -109,9 +109,9 @@ class VoiceCLIP(nn.Module):
self.distributed_collect = distributed_collect
if not use_xformers:
# nn.Embedding
self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text)
self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
# nn.Embedding
self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
def embed_text(self, text):
text_mask = torch.ones_like(text.float()).bool()

View File

@ -6,7 +6,7 @@ import math
import torch as th
import torch.nn as nn
import bitsandbytes as bnb
import torch_intermediary as ml
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
@ -37,7 +37,7 @@ def linear(*args, **kwargs):
"""
Create a linear module.
"""
return bnb.nn.Linear8bitLt(*args, **kwargs)
return ml.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):

View File

@ -6,7 +6,7 @@ from models.arch_util import ConvGnLelu, default_init_weights, make_layer
from models.diffusion.nn import timestep_embedding
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
# Conditionally uses torch's checkpoint functionality if it is enabled in the opt file.
@ -29,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True)
self.emb_layers = nn.Sequential(
nn.SiLU(),
bnb.nn.Linear8bitLt(
ml.Linear(
mid_channels*4,
mid_channels,
),
@ -144,9 +144,9 @@ class RRDBNet(nn.Module):
# Guided diffusion uses a time embedding.
time_embed_dim = mid_channels * 4
self.time_embed = nn.Sequential(
bnb.nn.Linear8bitLt(mid_channels, time_embed_dim),
ml.Linear(mid_channels, time_embed_dim),
nn.SiLU(),
bnb.nn.Linear8bitLt(time_embed_dim, time_embed_dim),
ml.Linear(time_embed_dim, time_embed_dim),
)
self.body = make_layer(

View File

@ -20,7 +20,7 @@ from models.diffusion.nn import (
)
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
class AttentionPool2d(nn.Module):
@ -517,7 +517,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
# nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
@ -869,16 +869,16 @@ class EncoderUNetModel(nn.Module):
)
elif pool == "spatial":
self.out = nn.Sequential(
bnb.nn.Linear8bitLt(self._feature_size, 2048),
ml.Linear(self._feature_size, 2048),
nn.ReLU(),
bnb.nn.Linear8bitLt(2048, self.out_channels),
ml.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
self.out = nn.Sequential(
bnb.nn.Linear8bitLt(self._feature_size, 2048),
ml.Linear(self._feature_size, 2048),
normalization(2048),
nn.SiLU(),
bnb.nn.Linear8bitLt(2048, self.out_channels),
ml.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")

View File

@ -26,7 +26,7 @@ from models.diffusion.nn import (
)
from trainer.networks import register_model
from utils.util import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
class AttentionPool2d(nn.Module):
@ -478,7 +478,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
# nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[
@ -738,7 +738,7 @@ class ResNetEncoder(nn.Module):
dilate=replace_stride_with_dilation[2])
f=512
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = bnb.nn.Linear8bitLt(f * block.expansion, output_dim)
self.fc = ml.Linear(f * block.expansion, output_dim)
for m in self.modules():
if isinstance(m, nn.Conv2d):

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
class Discriminator_VGG_128(nn.Module):
@ -47,8 +47,8 @@ class Discriminator_VGG_128(nn.Module):
input_img_factor = input_img_factor // 2
final_nf = nf * 16
self.linear1 = bnb.nn.Linear8bitLt(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
self.linear1 = ml.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
self.linear2 = ml.Linear(100, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
@ -130,8 +130,8 @@ class Discriminator_VGG_128_GN(nn.Module):
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
self.linear1 = ml.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
self.linear2 = ml.Linear(100, 1)
def compute_body(self, x):
fea = self.lrelu(self.conv0_0(x))
@ -220,8 +220,8 @@ class DiscriminatorVGG448GN(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
final_nf = nf * 8
self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 7 * 7), 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
self.linear1 = ml.Linear(int(final_nf * 7 * 7), 100)
self.linear2 = ml.Linear(100, 1)
# Assign all new heads to the new param group.2
for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]:

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import bitsandbytes as bnb
import torch_intermediary as ml
def initialize_weights(net_l, scale=1):
@ -15,7 +15,7 @@ def initialize_weights(net_l, scale=1):
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, bnb.nn.Linear8bitLt):
elif isinstance(m, ml.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:

View File

@ -28,7 +28,7 @@ except:
APEX_AVAILABLE = False
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
import bitsandbytes as bnb
import torch_intermediary as ml
num_cores = multiprocessing.cpu_count()
@ -352,7 +352,7 @@ class RGBBlock(nn.Module):
def __init__(self, latent_dim, input_channel, upsample, rgba=False):
super().__init__()
self.input_channel = input_channel
self.to_style = bnb.nn.Linear8bitLt(latent_dim, input_channel)
self.to_style = ml.Linear(latent_dim, input_channel)
out_filters = 3 if not rgba else 4
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
@ -490,16 +490,16 @@ class GeneratorBlockWithStructure(nn.Module):
# Uses stylegan1 style blocks for injecting structural latent.
self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
self.to_noise0 = bnb.nn.Linear8bitLt(1, filters)
self.to_noise0 = ml.Linear(1, filters)
self.noise0 = equal_lr(NoiseInjection(filters))
self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, filters)
self.to_noise1 = bnb.nn.Linear8bitLt(1, filters)
self.to_style1 = ml.Linear(latent_dim, filters)
self.to_noise1 = ml.Linear(1, filters)
self.conv1 = Conv2DMod(filters, filters, 3)
self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters)
self.to_noise2 = bnb.nn.Linear8bitLt(1, filters)
self.to_style2 = ml.Linear(latent_dim, filters)
self.to_noise2 = ml.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
@ -541,12 +541,12 @@ class GeneratorBlock(nn.Module):
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
input_channels = input_channels * 2
self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, input_channels)
self.to_noise1 = bnb.nn.Linear8bitLt(1, filters)
self.to_style1 = ml.Linear(latent_dim, input_channels)
self.to_noise1 = ml.Linear(1, filters)
self.conv1 = Conv2DMod(input_channels, filters, 3)
self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters)
self.to_noise2 = bnb.nn.Linear8bitLt(1, filters)
self.to_style2 = ml.Linear(latent_dim, filters)
self.to_noise2 = ml.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
@ -725,7 +725,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt} and hasattr(m, 'weight'):
if type(m) in {nn.Conv2d, ml.Linear} and hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks:
@ -805,7 +805,7 @@ class StyleGan2Discriminator(nn.Module):
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.flatten = Flatten()
self.to_logit = bnb.nn.Linear8bitLt(latent_dim, 1)
self.to_logit = ml.Linear(latent_dim, 1)
self._init_weights()
@ -837,7 +837,7 @@ class StyleGan2Discriminator(nn.Module):
def _init_weights(self):
for m in self.modules():
if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt}:
if type(m) in {nn.Conv2d, ml.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

View File

@ -12,7 +12,7 @@ from torch import nn
from data.images.byol_attachment import RandomApply
from trainer.networks import register_model, create_model
from utils.util import checkpoint, opt_get
import bitsandbytes as bnb
import torch_intermediary as ml
def default(val, def_val):
@ -79,10 +79,10 @@ class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size=4096):
super().__init__()
self.net = nn.Sequential(
bnb.nn.Linear8bitLt(dim, hidden_size),
ml.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
bnb.nn.Linear8bitLt(hidden_size, projection_size)
ml.Linear(hidden_size, projection_size)
)
def forward(self, x):
@ -104,10 +104,10 @@ class StructuralMLP(nn.Module):
nn.BatchNorm2d(c),
nn.ReLU(inplace=True),
nn.Flatten(),
bnb.nn.Linear8bitLt(flattened_dim, hidden_size),
ml.Linear(flattened_dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
bnb.nn.Linear8bitLt(hidden_size, projection_size)
ml.Linear(hidden_size, projection_size)
)
def forward(self, x):

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
import bitsandbytes as bnb
import torch_intermediary as ml
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
@ -109,8 +109,8 @@ class FixupResNet(nn.Module):
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
self.bias2 = nn.Parameter(torch.zeros(1))
reduced_img_sz = int(input_img_size / 32)
self.fc1 = bnb.nn.Linear8bitLt(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = bnb.nn.Linear8bitLt(100, num_classes)
self.fc1 = ml.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = ml.Linear(100, num_classes)
for m in self.modules():
if isinstance(m, FixupBasicBlock):
@ -125,7 +125,7 @@ class FixupResNet(nn.Module):
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
'''
elif isinstance(m, bnb.nn.Linear8bitLt):
elif isinstance(m, ml.Linear):
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)'''

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from models.arch_util import ResBlock
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
import bitsandbytes as bnb
import torch_intermediary as ml
class VitLatent(nn.Module):
@ -32,10 +32,10 @@ class VitLatent(nn.Module):
do_checkpointing=True
)
self.mlp = nn.Sequential(bnb.nn.Linear8bitLt(hidden_dim, hidden_dim*2),
self.mlp = nn.Sequential(ml.Linear(hidden_dim, hidden_dim*2),
nn.BatchNorm1d(hidden_dim*2),
nn.ReLU(inplace=True),
bnb.nn.Linear8bitLt(hidden_dim*2, hidden_dim))
ml.Linear(hidden_dim*2, hidden_dim))
def provide_ema(self, ema):
self.ema = ema

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from rotary_embedding_torch import apply_rotary_emb
import bitsandbytes as bnb
import torch_intermediary as ml
# helpers
@ -48,9 +48,9 @@ class Attention(nn.Module):
self.stable = stable
self.causal = causal
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
bnb.nn.Linear8bitLt(inner_dim, dim),
ml.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
@ -103,10 +103,10 @@ class SparseConvCausalAttention(nn.Module):
self.stable = stable
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
bnb.nn.Linear8bitLt(inner_dim, dim),
ml.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
@ -223,10 +223,10 @@ class SparseAxialCausalAttention(nn.Module):
self.stable = stable
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
bnb.nn.Linear8bitLt(inner_dim, dim),
ml.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

View File

@ -11,7 +11,7 @@ from models.lucidrains.dalle.attention import Attention, SparseAttention, Sparse
from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock
import bitsandbytes as bnb
import torch_intermediary as ml
# helpers
@ -79,10 +79,10 @@ class FeedForward(nn.Module):
def __init__(self, dim, dropout = 0., mult = 4.):
super().__init__()
self.net = nn.Sequential(
bnb.nn.Linear8bitLt(dim, dim * mult * 2),
ml.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
bnb.nn.Linear8bitLt(dim * mult, dim)
ml.Linear(dim * mult, dim)
)
def forward(self, x):

View File

@ -21,7 +21,7 @@ try:
APEX_AVAILABLE = True
except:
APEX_AVAILABLE = False
import bitsandbytes as bnb
import torch_intermediary as ml
# helpers
@ -357,10 +357,10 @@ class FeedForward(nn.Module):
activation = default(activation, nn.GELU)
self.glu = glu
self.w1 = bnb.nn.Linear8bitLt(dim, dim * mult * (2 if glu else 1))
self.w1 = ml.Linear(dim, dim * mult * (2 if glu else 1))
self.act = activation()
self.dropout = nn.Dropout(dropout)
self.w2 = bnb.nn.Linear8bitLt(dim * mult, dim)
self.w2 = ml.Linear(dim * mult, dim)
def forward(self, x, **kwargs):
if not self.glu:
@ -402,10 +402,10 @@ class Attention(nn.Module):
self.global_heads = heads - local_heads
self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
self.to_q = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
self.to_k = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
self.to_v = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
self.to_out = bnb.nn.Linear8bitLt(inner_dim, dim, bias = attn_out_bias)
self.to_q = ml.Linear(dim, inner_dim, bias = qkv_bias)
self.to_k = ml.Linear(dim, inner_dim, bias = qkv_bias)
self.to_v = ml.Linear(dim, inner_dim, bias = qkv_bias)
self.to_out = ml.Linear(inner_dim, dim, bias = attn_out_bias)
self.dropout = nn.Dropout(dropout)
def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs):
@ -460,7 +460,7 @@ class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(max_seq_len, dim)
self.emb = ml.Embedding(max_seq_len, dim)
def forward(self, x):
t = torch.arange(x.shape[1], device=x.device)
@ -622,7 +622,7 @@ class PerformerLM(nn.Module):
self.max_seq_len = max_seq_len
# nn.Embedding
self.token_emb = bnb.nn.StableEmbedding(num_tokens, dim)
self.token_emb = ml.Embedding(num_tokens, dim)
if rotary_position_emb:
self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
@ -639,7 +639,7 @@ class PerformerLM(nn.Module):
self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens)
self.norm = nn.LayerNorm(dim)
self.to_out = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embed else None
self.to_out = ml.Linear(dim, num_tokens) if not tie_embed else None
def check_redraw_projections(self):
self.performer.check_redraw_projections()

View File

@ -8,7 +8,7 @@ from torch.cuda.amp import autocast
from einops import rearrange, repeat
from contextlib import contextmanager
import bitsandbytes as bnb
import torch_intermediary as ml
def par(t, nm):
@ -356,9 +356,9 @@ class VectorQuantize(nn.Module):
codebook_dim = default(codebook_dim, dim)
requires_projection = codebook_dim != dim
self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection \
self.project_in = ml.Linear(dim, codebook_dim) if requires_projection \
else nn.Identity()
self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection \
self.project_out = ml.Linear(codebook_dim, dim) if requires_projection \
else nn.Identity()
self.eps = eps

View File

@ -11,7 +11,7 @@ from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from torch.utils.checkpoint import checkpoint
import bitsandbytes as bnb
import torch_intermediary as ml
DEFAULT_DIM_HEAD = 64
@ -127,7 +127,7 @@ class AbsolutePositionalEmbedding(nn.Module):
super().__init__()
self.scale = dim ** -0.5
# nn.Embedding
self.emb = bnb.nn.StableEmbedding(max_seq_len, dim)
self.emb = ml.Embedding(max_seq_len, dim)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
@ -157,7 +157,7 @@ class RelativePositionBias(nn.Module):
self.num_buckets = num_buckets
self.max_distance = max_distance
# nn.Embedding
self.relative_attention_bias = bnb.nn.StableEmbedding(num_buckets, heads)
self.relative_attention_bias = ml.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
@ -363,7 +363,7 @@ class RMSScaleShiftNorm(nn.Module):
self.cdim = 1
self.pdim = -1
else:
self.scale_shift_process = bnb.nn.Linear8bitLt(embed_dim, dim * 2, bias=bias)
self.scale_shift_process = ml.Linear(embed_dim, dim * 2, bias=bias)
self.cdim = -1
self.pdim = 1
@ -450,7 +450,7 @@ class GLU(nn.Module):
def __init__(self, dim_in, dim_out, activation):
super().__init__()
self.act = activation
self.proj = bnb.nn.Linear8bitLt(dim_in, dim_out * 2)
self.proj = ml.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@ -475,7 +475,7 @@ class FeedForward(nn.Module):
activation = ReluSquared() if relu_squared else nn.GELU()
project_in = nn.Sequential(
bnb.nn.Linear8bitLt(dim, inner_dim),
ml.Linear(dim, inner_dim),
activation
) if not glu else GLU(dim, inner_dim, activation)
@ -483,7 +483,7 @@ class FeedForward(nn.Module):
project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout),
bnb.nn.Linear8bitLt(inner_dim, dim_out)
ml.Linear(inner_dim, dim_out)
)
# init last linear layer to 0
@ -538,16 +538,16 @@ class Attention(nn.Module):
qk_dim = int(collab_compression * qk_dim)
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
self.to_q = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False)
self.to_k = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False)
self.to_v = bnb.nn.Linear8bitLt(dim, v_dim, bias=False)
self.to_q = ml.Linear(dim, qk_dim, bias=False)
self.to_k = ml.Linear(dim, qk_dim, bias=False)
self.to_v = ml.Linear(dim, v_dim, bias=False)
self.dropout = nn.Dropout(dropout)
# add GLU gating for aggregated values, from alphafold2
self.to_v_gate = None
if gate_values:
self.to_v_gate = bnb.nn.Linear8bitLt(dim, v_dim)
self.to_v_gate = ml.Linear(dim, v_dim)
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 1)
@ -584,7 +584,7 @@ class Attention(nn.Module):
# attention on attention
self.attn_on_attn = on_attn
out_dim = default(out_dim, dim)
self.to_out = nn.Sequential(bnb.nn.Linear8bitLt(v_dim, out_dim * 2), nn.GLU()) if on_attn else bnb.nn.Linear8bitLt(v_dim, out_dim)
self.to_out = nn.Sequential(ml.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, out_dim)
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
@ -1080,7 +1080,7 @@ class ViTransformerWrapper(nn.Module):
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = bnb.nn.Linear8bitLt(patch_dim, dim)
self.patch_to_embedding = ml.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
@ -1139,18 +1139,18 @@ class TransformerWrapper(nn.Module):
self.shift_mem_down = shift_mem_down
# nn.Embedding
self.token_emb = bnb.nn.StableEmbedding(num_tokens, emb_dim)
self.token_emb = ml.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = bnb.nn.Linear8bitLt(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
@ -1237,12 +1237,12 @@ class ContinuousTransformerWrapper(nn.Module):
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = bnb.nn.Linear8bitLt(dim_in, dim) if exists(dim_in) else nn.Identity()
self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.project_out = bnb.nn.Linear8bitLt(dim, dim_out) if exists(dim_out) else nn.Identity()
self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
def forward(
self,

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from torch import einsum
from utils.weight_scheduler import LinearDecayWeightScheduler
import bitsandbytes as bnb
import torch_intermediary as ml
class GumbelQuantizer(nn.Module):
@ -12,7 +12,7 @@ class GumbelQuantizer(nn.Module):
super().__init__()
self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
# nn.Embedding
self.codebook = bnb.nn.StableEmbedding(num_tokens, codebook_dim)
self.codebook = ml.Embedding(num_tokens, codebook_dim)
self.straight_through = straight_through
self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
self.step = 0

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from models.arch_util import l2norm, sample_vectors, default, ema_inplace
import bitsandbytes as bnb
import torch_intermediary as ml
def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
@ -185,8 +185,8 @@ class VectorQuantize(nn.Module):
codebook_dim = default(codebook_dim, dim)
requires_projection = codebook_dim != dim
self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection else nn.Identity()
self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection else nn.Identity()
self.project_in = ml.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
self.project_out = ml.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
self.eps = eps

View File

@ -0,0 +1,41 @@
"""
from bitsandbytes.nn import Linear8bitLt as Linear
from bitsandbytes.nn import StableEmbedding as Embedding
from bitsandbytes.optim.adam import Adam8bit as Adam
from bitsandbytes.optim.adamw import AdamW8bit as AdamW
"""
"""
from torch.nn import Linear
from torch.nn import Embedding
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
"""
OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = False
OVERRIDE_ADAM = True
OVERRIDE_ADAMW = True
USE_STABLE_EMBEDDING = True
if OVERRIDE_LINEAR:
from bitsandbytes.nn import Linear8bitLt as Linear
else:
from torch.nn import Linear
if OVERRIDE_EMBEDDING:
if USE_STABLE_EMBEDDING:
from bitsandbytes.nn import StableEmbedding as Embedding
else:
from bitsandbytes.nn import Embedding as Embedding
else:
from torch.nn import Embedding
if OVERRIDE_ADAM:
from bitsandbytes.optim.adam import Adam8bit as Adam
else:
from torch.optim.adam import Adam
if OVERRIDE_ADAMW:
from bitsandbytes.optim.adamw import AdamW8bit as AdamW
else:
from torch.optim.adamw import AdamW

View File

@ -21,7 +21,7 @@ import torchvision.utils as utils
from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
from utils.util import opt_get, denormalize
import bitsandbytes as bnb
import torch_intermediary as ml
logger = logging.getLogger('base')
@ -338,7 +338,7 @@ class ExtensibleTrainer(BaseModel):
for net in self.networks.values():
for mod in net.modules():
fan_in = -1
if isinstance(mod, bnb.nn.Linear8bitLt):
if isinstance(mod, ml.Linear):
fan_in = mod.weight.data.shape[1]
elif isinstance(mod, nn.Conv1d):
fan_in = mod.weight.data.shape[0]

View File

@ -7,7 +7,7 @@ import torch.nn as nn
import trainer.networks as networks
import trainer.lr_scheduler as lr_scheduler
from .base_model import BaseModel
import bitsandbytes as bnb
import torch_intermediary as ml
logger = logging.getLogger('base')
@ -43,7 +43,7 @@ class FeatureModel(BaseModel):
if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
# torch.optim.Adam
self.optimizer_G = bnb.optim.Adam8bit(optim_params, lr=train_opt['lr_G'],
self.optimizer_G = ml.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G,
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
self.optimizers.append(self.optimizer_G)

View File

@ -3,7 +3,7 @@ from collections import Counter
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import _LRScheduler
import bitsandbytes as bnb
import torch_intermediary as ml
from utils.util import opt_get
@ -137,7 +137,7 @@ class CosineAnnealingLR_Restart(_LRScheduler):
if __name__ == "__main__":
#torch.optim.Adam
optimizer = bnb.optim.Adam8bit([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
optimizer = ml.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
betas=(0.9, 0.99))
##############################
# MultiStepLR_Restart

View File

@ -12,7 +12,7 @@ from utils.util import recursively_detach, opt_get, clip_grad_norm
logger = logging.getLogger('base')
import bitsandbytes as bnb
import torch_intermediary as ml
# Defines the expected API for a single training step
class ConfigurableStep(Module):
@ -84,7 +84,7 @@ class ConfigurableStep(Module):
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
# nn.Embedding
emb_modules = (bnb.nn.StableEmbedding, nn.EmbeddingBag)
emb_modules = (ml.Embedding, nn.EmbeddingBag)
param_names_notweights = set()
all_param_names = set()
param_map = {}
@ -126,7 +126,7 @@ class ConfigurableStep(Module):
{ 'params': params_notweights, 'weight_decay': 0 }
]
# torch.optim.AdamW
opt = bnb.optim.AdamW8bit(groups, lr=opt_config['lr'],
opt = ml.AdamW(groups, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt._group_names = [params_names_weights, params_names_notweights]
@ -145,7 +145,7 @@ class ConfigurableStep(Module):
# parameters and just use a normal AdamW implementation. In a large network, these weights will normally
# be a tiny fraction of the total weights.
# torch.optim.AdamW
opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0,
opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt_unweighted._config = opt_config
opt_unweighted._config['network'] = net_name
@ -153,7 +153,7 @@ class ConfigurableStep(Module):
self.optimizers.append(opt_unweighted)
# torch.optim.AdamW
opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=bnb.optim.AdamW8bit, lr=opt_config['lr'],
opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=ml.AdamW, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt.param_groups[0]['initial_lr'] = opt_config['lr']
@ -168,7 +168,7 @@ class ConfigurableStep(Module):
elif self.step_opt['optimizer'] == 'lamb':
from trainer.optimizers.lamb import Lamb
# torch.optim.AdamW
opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0,
opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt_unweighted._config = opt_config
opt_unweighted._config['network'] = net_name