I sucked off the hypothetical wizard again: just using BNB's ADAM optimizer nets HUGE savings, but I don't know the output costs yet, so this will need testing

This commit is contained in:
mrq 2023-02-23 02:42:17 +00:00
parent 01c0941a40
commit 6676c89c0e
68 changed files with 317 additions and 276 deletions

View File

@ -9,7 +9,7 @@ import torch.nn.utils.spectral_norm as SpectralNorm
from math import sqrt from math import sqrt
from utils.util import checkpoint from utils.util import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
def exists(val): def exists(val):
@ -74,7 +74,7 @@ def initialize_weights(net_l, scale=1):
m.weight.data *= scale # for residual block m.weight.data *= scale # for residual block
if m.bias is not None: if m.bias is not None:
m.bias.data.zero_() m.bias.data.zero_()
elif isinstance(m, bnb.nn.Linear8bitLt): elif isinstance(m, ml.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in') init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale m.weight.data *= scale
if m.bias is not None: if m.bias is not None:
@ -109,7 +109,7 @@ def default_init_weights(module, scale=1):
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
kaiming_init(m, a=0, mode='fan_in', bias=0) kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale m.weight.data *= scale
elif isinstance(m, bnb.nn.Linear8bitLt): elif isinstance(m, ml.Linear):
kaiming_init(m, a=0, mode='fan_in', bias=0) kaiming_init(m, a=0, mode='fan_in', bias=0)
m.weight.data *= scale m.weight.data *= scale
@ -142,7 +142,7 @@ def linear(*args, **kwargs):
""" """
Create a linear module. Create a linear module.
""" """
return bnb.nn.Linear8bitLt(*args, **kwargs) return ml.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs): def avg_pool_nd(dims, *args, **kwargs):

View File

@ -9,7 +9,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.tts.tacotron2.text import sequence_to_text from models.audio.tts.tacotron2.text import sequence_to_text
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
import bitsandbytes as bnb import torch_intermediary as ml
def only_letters(string): def only_letters(string):
@ -52,7 +52,7 @@ class Wav2VecWrapper(nn.Module):
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model) self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
# Perform some surgery to get the model we actually want. # Perform some surgery to get the model we actually want.
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
self.w2v.lm_head = bnb.nn.Linear8bitLt(self.w2v.config.hidden_size, vocab_size) self.w2v.lm_head = ml.Linear(self.w2v.config.hidden_size, vocab_size)
self.w2v.config.vocab_size = vocab_size self.w2v.config.vocab_size = vocab_size
self.w2v.config.pad_token_id = 0 self.w2v.config.pad_token_id = 0
self.w2v.config.ctc_loss_reduction = 'sum' self.w2v.config.ctc_loss_reduction = 'sum'

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
from typing import Type, Any, Callable, Union, List, Optional from typing import Type, Any, Callable, Union, List, Optional
import bitsandbytes as bnb import torch_intermediary as ml
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@ -173,7 +173,7 @@ class ResNet(nn.Module):
self.layer4 = self._make_layer(block, 512, layers[3], stride=4, self.layer4 = self._make_layer(block, 512, layers[3], stride=4,
dilate=replace_stride_with_dilation[2]) dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool1d(1) self.avgpool = nn.AdaptiveAvgPool1d(1)
self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes) self.fc = ml.Linear(512 * block.expansion, num_classes)
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv1d): if isinstance(m, nn.Conv1d):

View File

@ -15,14 +15,14 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from models.arch_util import ResBlock from models.arch_util import ResBlock
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint from utils.util import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
class Mel2Vec2FeatureProjection(nn.Module): class Mel2Vec2FeatureProjection(nn.Module):
def __init__(self, inner_dim, dropout): def __init__(self, inner_dim, dropout):
super().__init__() super().__init__()
self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5) self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5)
self.projection = bnb.nn.Linear8bitLt(inner_dim, inner_dim) self.projection = ml.Linear(inner_dim, inner_dim)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def forward(self, hidden_states): def forward(self, hidden_states):
@ -59,10 +59,10 @@ class Wav2Vec2Attention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder self.is_decoder = is_decoder
self.k_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) self.k_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) self.v_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) self.q_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias) self.out_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@ -183,10 +183,10 @@ class Wav2Vec2FeedForward(nn.Module):
super().__init__() super().__init__()
self.intermediate_dropout = nn.Dropout(dropout) self.intermediate_dropout = nn.Dropout(dropout)
self.intermediate_dense = bnb.nn.Linear8bitLt(hidden_size, intermediate_size) self.intermediate_dense = ml.Linear(hidden_size, intermediate_size)
self.intermediate_act_fn = F.gelu self.intermediate_act_fn = F.gelu
self.output_dense = bnb.nn.Linear8bitLt(intermediate_size, hidden_size) self.output_dense = ml.Linear(intermediate_size, hidden_size)
self.output_dropout = nn.Dropout(dropout) self.output_dropout = nn.Dropout(dropout)
def forward(self, hidden_states): def forward(self, hidden_states):
@ -430,7 +430,7 @@ class Mel2Vec(nn.Module):
k = math.sqrt(1 / module.projection.in_features) k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, bnb.nn.Linear8bitLt): elif isinstance(module, ml.Linear):
if self.disable_custom_linear_init: if self.disable_custom_linear_init:
return return
module.weight.data.normal_(mean=0.0, std=self.linear_init_scale) module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
@ -511,7 +511,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter( self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
) )
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training # can be decayed for training
self.temperature = 2 self.temperature = 2
@ -607,8 +607,8 @@ class ContrastiveTrainingWrapper(nn.Module):
self.inp_length_factor = inp_length_multiplier self.inp_length_factor = inp_length_multiplier
# make sure that project_hid & project_q are initialized like normal linear layers # make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = bnb.nn.Linear8bitLt(inner_dim, self.quantizer.codevector_dim) self.project_hid = ml.Linear(inner_dim, self.quantizer.codevector_dim)
self.project_q = bnb.nn.Linear8bitLt(self.quantizer.codevector_dim, self.quantizer.codevector_dim) self.project_q = ml.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
self.reconstruction = do_reconstruction_loss self.reconstruction = do_reconstruction_loss
if do_reconstruction_loss: if do_reconstruction_loss:

View File

@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import GPT2Config, GPT2Model from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE from models.audio.tts.lucidrains_dvae import DiscreteVAE
@ -57,8 +57,8 @@ class ConditioningAR(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings. del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding # nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(num_vectors, dim) self.embeddings = ml.Embedding(num_vectors, dim)
self.head = bnb.nn.Linear8bitLt(dim, num_vectors) self.head = ml.Linear(dim, num_vectors)
def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False): def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
unused_params = [] unused_params = []

View File

@ -17,7 +17,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from math import sqrt from math import sqrt
@ -25,7 +25,7 @@ from torch.utils.checkpoint import checkpoint
from trainer.networks import register_model from trainer.networks import register_model
Linear = bnb.nn.Linear8bitLt Linear = ml.Linear
ConvTranspose2d = nn.ConvTranspose2d ConvTranspose2d = nn.ConvTranspose2d

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import autocast from torch import autocast
import bitsandbytes as bnb import torch_intermediary as ml
from models.arch_util import ResBlock from models.arch_util import ResBlock
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@ -24,7 +24,7 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim): def __init__(self, tokens, groups, dim):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x): def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -161,7 +161,7 @@ class FlatDiffusion(nn.Module):
# transformer network. # transformer network.
if in_groups is None: if in_groups is None:
# nn.Embedding # nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels) self.embeddings = ml.Embedding(token_count, model_channels)
else: else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential( self.latent_conditioner = nn.Sequential(

View File

@ -2,7 +2,7 @@ import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock from models.arch_util import AttentionBlock, ResBlock
from models.audio.music.music_quantizer import MusicQuantizer from models.audio.music.music_quantizer import MusicQuantizer
@ -138,8 +138,8 @@ class GptMusicLower(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings. del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding # nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)]) self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
def forward(self, mel, conditioning, return_latent=False): def forward(self, mel, conditioning, return_latent=False):
unused_params = [] unused_params = []
@ -241,8 +241,8 @@ class GptMusicUpper(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings. del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding # nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)]) self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_upper_vectors) for _ in range(num_upper_groups)]) self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
def forward(self, mel, conditioning, return_latent=False): def forward(self, mel, conditioning, return_latent=False):

View File

@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import GPT2Config, GPT2Model from transformers import GPT2Config, GPT2Model
import bitsandbytes as bnb import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE from models.audio.tts.lucidrains_dvae import DiscreteVAE
@ -75,8 +75,8 @@ class GptMusicLower(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings. del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding # nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)]) self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
def forward(self, mel, return_latent=False): def forward(self, mel, return_latent=False):
unused_params = [] unused_params = []

View File

@ -3,7 +3,7 @@ import functools
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding from models.diffusion.nn import timestep_embedding
from models.lucidrains.vq import VectorQuantize from models.lucidrains.vq import VectorQuantize
@ -22,8 +22,8 @@ class SelfClassifyingHead(nn.Module):
use_rmsnorm=True, ff_glu=True, do_checkpointing=False) use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2, self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2,
sample_codebook_temp=init_temperature) sample_codebook_temp=init_temperature)
self.to_output = bnb.nn.Linear8bitLt(dim, out_dim) self.to_output = ml.Linear(dim, out_dim)
self.to_decoder = bnb.nn.Linear8bitLt(out_dim, dim) self.to_decoder = ml.Linear(out_dim, dim)
def do_ar_step(self, x, used_codes): def do_ar_step(self, x, used_codes):
h = self.dec(x) h = self.dec(x)
@ -91,7 +91,7 @@ class InstrumentQuantizer(nn.Module):
""" """
super().__init__() super().__init__()
self.op_dim = op_dim self.op_dim = op_dim
self.proj = bnb.nn.Linear8bitLt(op_dim, dim) self.proj = ml.Linear(op_dim, dim)
self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)]) self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp) self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
self.min_gumbel_temperature = min_temp self.min_gumbel_temperature = min_temp

View File

@ -1,7 +1,7 @@
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from transformers import GPT2Config, GPT2Model from transformers import GPT2Config, GPT2Model
from trainer.networks import register_model from trainer.networks import register_model
@ -19,8 +19,8 @@ class Mel2VecCodesGpt(nn.Module):
self.gpt = GPT2Model(self.config) self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings. del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding # nn.Embedding
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_vectors, dim//num_groups) for _ in range(num_groups)]) self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_vectors) for _ in range(num_groups)]) self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)])
def forward(self, codes): def forward(self, codes):
assert codes.shape[-1] == self.num_groups assert codes.shape[-1] == self.num_groups

View File

@ -3,7 +3,7 @@ import functools
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from models.arch_util import zero_module from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize from models.vqvae.vqvae import Quantize
@ -76,7 +76,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter( self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
) )
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training # can be decayed for training
self.temperature = 2 self.temperature = 2

View File

@ -3,7 +3,7 @@ import functools
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from models.arch_util import zero_module from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize from models.vqvae.vqvae import Quantize
@ -88,7 +88,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.codevectors = nn.Parameter( self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
) )
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars) self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
# can be decayed for training # can be decayed for training
self.temperature = 2 self.temperature = 2

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
import torchvision import torchvision
import bitsandbytes as bnb import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock from models.diffusion.unet_diffusion import TimestepBlock
@ -56,12 +56,12 @@ class ConcatAttentionBlock(TimestepBlock):
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
if cond_projection: if cond_projection:
self.tdim = trunk_dim+cond_dim_hidden self.tdim = trunk_dim+cond_dim_hidden
self.cond_project = bnb.nn.Linear8bitLt(cond_dim_in, cond_dim_hidden) self.cond_project = ml.Linear(cond_dim_in, cond_dim_hidden)
else: else:
self.tdim = trunk_dim self.tdim = trunk_dim
self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv) self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv)
self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv) self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv)
self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False) self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
self.out.weight.data.zero_() self.out.weight.data.zero_()
def forward(self, x, cond, timestep_emb, rotary_emb): def forward(self, x, cond, timestep_emb, rotary_emb):
@ -89,7 +89,7 @@ class ConditioningEncoder(nn.Module):
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1) self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
self.time_proj = time_proj self.time_proj = time_proj
if time_proj: if time_proj:
self.time_proj = bnb.nn.Linear8bitLt(time_embed_dim, embedding_dim) self.time_proj = ml.Linear(time_embed_dim, embedding_dim)
self.attn = Encoder( self.attn = Encoder(
dim=embedding_dim, dim=embedding_dim,
depth=attn_blocks, depth=attn_blocks,

View File

@ -4,7 +4,7 @@ from time import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from models.arch_util import ResBlock from models.arch_util import ResBlock
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
@ -29,7 +29,7 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim): def __init__(self, tokens, groups, dim):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x): def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -70,7 +70,7 @@ class ConcatAttentionBlock(TimestepBlock):
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False) self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout) self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout) self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False) self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
self.out.weight.data.zero_() self.out.weight.data.zero_()
def forward(self, x, timestep_emb, rotary_emb): def forward(self, x, timestep_emb, rotary_emb):
@ -131,7 +131,7 @@ class TransformerDiffusion(nn.Module):
) )
prenet_heads = prenet_channels//64 prenet_heads = prenet_channels//64
self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, prenet_channels) self.input_converter = ml.Linear(input_vec_dim, prenet_channels)
self.code_converter = Encoder( self.code_converter = Encoder(
dim=prenet_channels, dim=prenet_channels,
depth=prenet_layers, depth=prenet_layers,
@ -147,7 +147,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels) self.intg = ml.Linear(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)]) self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential( self.out = nn.Sequential(

View File

@ -5,7 +5,7 @@ from random import randrange
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \ from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
RelativeQKBias RelativeQKBias
@ -71,13 +71,13 @@ class ConditioningEncoder(nn.Module):
attn = [] attn = []
self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2) self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
# nn.Embedding # nn.Embedding
self.resolution_embedding = bnb.nn.StableEmbedding(num_resolutions, hidden_dim) self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim)
self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start. self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
for a in range(attn_blocks): for a in range(attn_blocks):
attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing)) attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing)) attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
self.attn = nn.Sequential(*attn) self.attn = nn.Sequential(*attn)
self.out = bnb.nn.Linear8bitLt(hidden_dim, out_dim, bias=False) self.out = ml.Linear(hidden_dim, out_dim, bias=False)
self.dim = hidden_dim self.dim = hidden_dim
self.do_checkpointing = do_checkpointing self.do_checkpointing = do_checkpointing
@ -134,7 +134,7 @@ class TransformerDiffusion(nn.Module):
linear(time_embed_dim, time_proj_dim), linear(time_embed_dim, time_proj_dim),
) )
# nn.Embedding # nn.Embedding
self.resolution_embed = bnb.nn.StableEmbedding(resolution_steps, time_proj_dim) self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim)
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64) self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim)) self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))

View File

@ -8,7 +8,7 @@ import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision # For debugging, not actually used. import torchvision # For debugging, not actually used.
import bitsandbytes as bnb import torch_intermediary as ml
from models.audio.music.gpt_music import GptMusicLower from models.audio.music.gpt_music import GptMusicLower
from models.audio.music.music_quantizer import MusicQuantizer from models.audio.music.music_quantizer import MusicQuantizer
@ -491,7 +491,7 @@ class UNetMusicModel(nn.Module):
) )
if self.ar_prior: if self.ar_prior:
self.ar_input = bnb.nn.Linear8bitLt(input_vec_dim, model_channels) self.ar_input = ml.Linear(input_vec_dim, model_channels)
self.ar_prior_intg = Encoder( self.ar_prior_intg = Encoder(
dim=model_channels, dim=model_channels,
depth=4, depth=4,
@ -505,7 +505,7 @@ class UNetMusicModel(nn.Module):
ff_mult=1, ff_mult=1,
) )
else: else:
self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, model_channels) self.input_converter = ml.Linear(input_vec_dim, model_channels)
self.code_converter = Encoder( self.code_converter = Encoder(
dim=model_channels, dim=model_channels,
depth=4, depth=4,
@ -523,7 +523,7 @@ class UNetMusicModel(nn.Module):
if self.num_classes is not None: if self.num_classes is not None:
# nn.Embedding # nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive. assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.

View File

@ -3,7 +3,7 @@ from random import random
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
from models.lucidrains.x_transformers import Encoder from models.lucidrains.x_transformers import Encoder
@ -38,11 +38,11 @@ class CtcCodeGenerator(nn.Module):
pred_codes = (max_pad+1)*(max_repeat+1) pred_codes = (max_pad+1)*(max_repeat+1)
# nn.Embedding # nn.Embedding
self.position_embedding = bnb.nn.StableEmbedding(max_length, model_dim) self.position_embedding = ml.Embedding(max_length, model_dim)
# nn.Embedding # nn.Embedding
self.codes_embedding = bnb.nn.StableEmbedding(ctc_codes, model_dim) self.codes_embedding = ml.Embedding(ctc_codes, model_dim)
# nn.Embedding # nn.Embedding
self.recursive_embedding = bnb.nn.StableEmbedding(pred_codes, model_dim) self.recursive_embedding = ml.Embedding(pred_codes, model_dim)
self.mask_embedding = nn.Parameter(torch.randn(model_dim)) self.mask_embedding = nn.Parameter(torch.randn(model_dim))
self.encoder = Encoder( self.encoder = Encoder(
dim=model_dim, dim=model_dim,
@ -54,8 +54,8 @@ class CtcCodeGenerator(nn.Module):
ff_glu=True, ff_glu=True,
rotary_pos_emb=True, rotary_pos_emb=True,
) )
self.pred_head = bnb.nn.Linear8bitLt(model_dim, pred_codes) self.pred_head = ml.Linear(model_dim, pred_codes)
self.confidence_head = bnb.nn.Linear8bitLt(model_dim, 1) self.confidence_head = ml.Linear(model_dim, 1)
def inference(self, codes, pads, repeats): def inference(self, codes, pads, repeats):
position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device)) position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))

View File

@ -5,7 +5,7 @@ from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import bitsandbytes as bnb import torch_intermediary as ml
from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \ from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \ DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
@ -18,7 +18,7 @@ class TimeIntegrationBlock(nn.Module):
super().__init__() super().__init__()
self.emb_layers = nn.Sequential( self.emb_layers = nn.Sequential(
nn.SiLU(), nn.SiLU(),
bnb.nn.Linear8bitLt( ml.Linear(
time_emb_dim, time_emb_dim,
2 * dim 2 * dim
), ),

View File

@ -1,6 +1,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import bitsandbytes as bnb import torch_intermediary as ml
from models.diffusion.nn import normalization, conv_nd, zero_module from models.diffusion.nn import normalization, conv_nd, zero_module
@ -139,7 +139,7 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
def __init__(self, classes, distribute_zero_label=True, **kwargs): def __init__(self, classes, distribute_zero_label=True, **kwargs):
super().__init__() super().__init__()
self.enc = AudioMiniEncoder(**kwargs) self.enc = AudioMiniEncoder(**kwargs)
self.head = bnb.nn.Linear8bitLt(self.enc.dim, classes) self.head = ml.Linear(self.enc.dim, classes)
self.num_classes = classes self.num_classes = classes
self.distribute_zero_label = distribute_zero_label self.distribute_zero_label = distribute_zero_label
@ -184,7 +184,7 @@ class QueryProvidedAttentionBlock(nn.Module):
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels self.num_heads = channels // num_head_channels
self.norm = normalization(channels) self.norm = normalization(channels)
self.q = bnb.nn.Linear8bitLt(channels, channels) self.q = ml.Linear(channels, channels)
self.qnorm = nn.LayerNorm(channels) self.qnorm = nn.LayerNorm(channels)
self.kv = conv_nd(1, channels, channels*2, 1) self.kv = conv_nd(1, channels, channels*2, 1)
if use_new_attention_order: if use_new_attention_order:

View File

@ -3,7 +3,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
@ -45,7 +45,7 @@ class RandomLatentConverter(nn.Module):
def __init__(self, channels): def __init__(self, channels):
super().__init__() super().__init__()
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)], self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
bnb.nn.Linear8bitLt(channels, channels)) ml.Linear(channels, channels))
self.channels = channels self.channels = channels
def forward(self, ref): def forward(self, ref):

View File

@ -3,13 +3,13 @@ from librosa.filters import mel as librosa_mel_fn
from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
from models.audio.tts.tacotron2.stft import STFT from models.audio.tts.tacotron2.stft import STFT
import bitsandbytes as bnb import torch_intermediary as ml
class LinearNorm(torch.nn.Module): class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__() super(LinearNorm, self).__init__()
self.linear_layer = torch.bnb.nn.Linear8bitLt(in_dim, out_dim, bias=bias) self.linear_layer = torch.ml.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_( torch.nn.init.xavier_uniform_(
self.linear_layer.weight, self.linear_layer.weight,

View File

@ -8,7 +8,7 @@ from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
from models.audio.tts.tacotron2.hparams import create_hparams from models.audio.tts.tacotron2.hparams import create_hparams
from trainer.networks import register_model from trainer.networks import register_model
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
import bitsandbytes as bnb import torch_intermediary as ml
class LocationLayer(nn.Module): class LocationLayer(nn.Module):
@ -465,7 +465,7 @@ class Tacotron2(nn.Module):
self.n_mel_channels = hparams.n_mel_channels self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step self.n_frames_per_step = hparams.n_frames_per_step
# nn.Embedding # nn.Embedding
self.embedding = bnb.nn.StableEmbedding( self.embedding = ml.Embedding(
hparams.n_symbols, hparams.symbols_embedding_dim) hparams.n_symbols, hparams.symbols_embedding_dim)
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std val = sqrt(3.0) * std # uniform bounds for std

View File

@ -13,7 +13,7 @@ from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
from trainer.networks import register_model from trainer.networks import register_model
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import checkpoint from utils.util import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
@ -187,7 +187,7 @@ class WaveTacotron2(nn.Module):
self.n_mel_channels = hparams.n_mel_channels self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step self.n_frames_per_step = hparams.n_frames_per_step
# nn.Embedding # nn.Embedding
self.embedding = bnb.nn.StableEmbedding( self.embedding = ml.Embedding(
hparams.n_symbols, hparams.symbols_embedding_dim) hparams.n_symbols, hparams.symbols_embedding_dim)
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std val = sqrt(3.0) * std # uniform bounds for std

View File

@ -25,7 +25,7 @@ import random
from time import time from time import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import bitsandbytes as bnb import torch_intermediary as ml
from tqdm import tqdm from tqdm import tqdm
@ -37,7 +37,7 @@ class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=.02, relative=False): def __init__(self, seq_len, model_dim, init=.02, relative=False):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.emb = bnb.nn.StableEmbedding(seq_len, model_dim) self.emb = ml.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2 # Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init) self.emb.weight.data.normal_(mean=0.0, std=init)
self.relative = relative self.relative = relative

View File

@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlo
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint from utils.util import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
def is_latent(t): def is_latent(t):
@ -21,7 +21,7 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim): def __init__(self, tokens, groups, dim):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x): def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -102,9 +102,9 @@ class TransformerDiffusionTTS(nn.Module):
ff_glu=True, ff_glu=True,
rotary_pos_emb=True, rotary_pos_emb=True,
) )
self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, model_channels) self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels)
# nn.Embedding # nn.Embedding
self.type_embedding = bnb.nn.StableEmbedding(types, model_channels) self.type_embedding = ml.Embedding(types, model_channels)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@ -112,7 +112,7 @@ class TransformerDiffusionTTS(nn.Module):
# transformer network. # transformer network.
if in_groups is None: if in_groups is None:
# nn.Embedding # nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels) self.embeddings = ml.Embedding(token_count, model_channels)
else: else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential( self.latent_conditioner = nn.Sequential(
@ -144,7 +144,7 @@ class TransformerDiffusionTTS(nn.Module):
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = bnb.nn.Linear8bitLt(model_channels*2, model_channels) self.intg = ml.Linear(model_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)]) self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
self.out = nn.Sequential( self.out = nn.Sequential(

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
@ -21,7 +21,7 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim): def __init__(self, tokens, groups, dim):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)]) self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x): def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -42,7 +42,7 @@ class DietAttentionBlock(TimestepBlock):
def __init__(self, in_dim, dim, heads, dropout): def __init__(self, in_dim, dim, heads, dropout):
super().__init__() super().__init__()
self.rms_scale_norm = RMSScaleShiftNorm(in_dim) self.rms_scale_norm = RMSScaleShiftNorm(in_dim)
self.proj = bnb.nn.Linear8bitLt(in_dim, dim) self.proj = ml.Linear(in_dim, dim)
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout)
self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
@ -107,9 +107,9 @@ class TransformerDiffusionTTS(nn.Module):
ff_glu=True, ff_glu=True,
rotary_pos_emb=True, rotary_pos_emb=True,
) )
self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, prenet_channels) self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels)
# nn.Embedding # nn.Embedding
self.type_embedding = bnb.nn.StableEmbedding(types, prenet_channels) self.type_embedding = ml.Embedding(types, prenet_channels)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@ -117,7 +117,7 @@ class TransformerDiffusionTTS(nn.Module):
# transformer network. # transformer network.
if in_groups is None: if in_groups is None:
# nn.Embedding # nn.Embedding
self.embeddings = bnb.nn.StableEmbedding(token_count, prenet_channels) self.embeddings = ml.Embedding(token_count, prenet_channels)
else: else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels) self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels)
self.latent_conditioner = nn.Sequential( self.latent_conditioner = nn.Sequential(
@ -148,8 +148,8 @@ class TransformerDiffusionTTS(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.cond_intg = bnb.nn.Linear8bitLt(prenet_channels*4, model_channels) self.cond_intg = ml.Linear(prenet_channels*4, model_channels)
self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels) self.intg = ml.Linear(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import autocast from torch import autocast
import bitsandbytes as bnb import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
@ -249,7 +249,7 @@ class DiffusionTts(nn.Module):
embedding_dim = model_channels * 8 embedding_dim = model_channels * 8
# nn.Embedding # nn.Embedding
self.code_embedding = bnb.nn.StableEmbedding(num_tokens+1, embedding_dim) self.code_embedding = ml.Embedding(num_tokens+1, embedding_dim)
self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1, self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1) self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1)
@ -257,7 +257,7 @@ class DiffusionTts(nn.Module):
self.enable_unaligned_inputs = enabled_unaligned_inputs self.enable_unaligned_inputs = enabled_unaligned_inputs
if enabled_unaligned_inputs: if enabled_unaligned_inputs:
# nn.Embedding # nn.Embedding
self.unaligned_embedder = bnb.nn.StableEmbedding(num_unaligned_tokens, embedding_dim) self.unaligned_embedder = ml.Embedding(num_unaligned_tokens, embedding_dim)
self.unaligned_encoder = CheckpointedXTransformerEncoder( self.unaligned_encoder = CheckpointedXTransformerEncoder(
max_seq_len=-1, max_seq_len=-1,
use_pos_emb=False, use_pos_emb=False,

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import autocast from torch import autocast
from x_transformers import Encoder from x_transformers import Encoder
import bitsandbytes as bnb import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
@ -208,7 +208,7 @@ class DiffusionTts(nn.Module):
# transformer network. # transformer network.
self.code_converter = nn.Sequential( self.code_converter = nn.Sequential(
# nn.Embedding # nn.Embedding
bnb.nn.StableEmbedding(in_tokens, conditioning_dim), ml.Embedding(in_tokens, conditioning_dim),
CheckpointedXTransformerEncoder( CheckpointedXTransformerEncoder(
needs_permute=False, needs_permute=False,
max_seq_len=-1, max_seq_len=-1,

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import autocast from torch import autocast
import bitsandbytes as bnb import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
@ -196,7 +196,7 @@ class DiffusionTtsFlat(nn.Module):
# transformer network. # transformer network.
# nn.Embedding # nn.Embedding
self.code_embedding = bnb.nn.StableEmbedding(in_tokens, model_channels) self.code_embedding = ml.Embedding(in_tokens, model_channels)
self.code_converter = nn.Sequential( self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),

View File

@ -13,7 +13,7 @@ from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_e
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
import bitsandbytes as bnb import torch_intermediary as ml
class ResBlock(nn.Module): class ResBlock(nn.Module):
""" """
@ -282,10 +282,10 @@ class UnifiedVoice(nn.Module):
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.average_conditioning_embeddings = average_conditioning_embeddings self.average_conditioning_embeddings = average_conditioning_embeddings
# nn.Embedding # nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens, model_dim) self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
if use_mel_codes_as_input: if use_mel_codes_as_input:
# nn.Embedding # nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
else: else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
@ -298,8 +298,8 @@ class UnifiedVoice(nn.Module):
self.text_solo_embedding = 0 self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim) self.final_norm = nn.LayerNorm(model_dim)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens) self.text_head = ml.Linear(model_dim, self.number_text_tokens)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme # Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding] embeddings = [self.text_embedding]

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from transformers import GPT2Config, GPT2PreTrainedModel from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
@ -274,16 +274,16 @@ class UnifiedVoice(nn.Module):
self.mel_length_compression = mel_length_compression self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
# nn.Embedding # nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim) self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
# nn.Embedding # nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim) self.final_norm = nn.LayerNorm(model_dim)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1) self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
self.aligned_head = bnb.nn.Linear8bitLt(model_dim, number_aligned_text_codes) self.aligned_head = ml.Linear(model_dim, number_aligned_text_codes)
# Initialize the embeddings per the GPT-2 scheme # Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding] embeddings = [self.text_embedding, self.mel_embedding]

View File

@ -11,7 +11,7 @@ from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
import bitsandbytes as bnb import torch_intermediary as ml
class ResBlock(nn.Module): class ResBlock(nn.Module):
@ -257,16 +257,16 @@ class UnifiedVoice(nn.Module):
self.mel_length_compression = mel_length_compression self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
# nn.Embedding # nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim) self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
# nn.Embedding # nn.Embedding
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim) self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim) self.final_norm = nn.LayerNorm(model_dim)
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1) self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes) self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
self.alignment_head = bnb.nn.Linear8bitLt(model_dim, 256) self.alignment_head = ml.Linear(model_dim, 256)
if only_alignment_head: if only_alignment_head:
for p in self.parameters(): for p in self.parameters():

View File

@ -8,7 +8,7 @@ from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.injectors.spec_augment import spec_augment from trainer.injectors.spec_augment import spec_augment
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
import bitsandbytes as bnb import torch_intermediary as ml
def exists(val): def exists(val):
@ -37,7 +37,7 @@ class VoiceCLIP(nn.Module):
self.encoder = AudioMiniEncoder(80, encoder_output) self.encoder = AudioMiniEncoder(80, encoder_output)
if pretrained_encoder_dict_path is not None: if pretrained_encoder_dict_path is not None:
self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path)) self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
self.to_latent = bnb.nn.Linear8bitLt(encoder_output, dim_latent, bias=False) self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.)) self.temperature = nn.Parameter(torch.tensor(1.))
self.mel_compression_ratio = mel_compression_ratio self.mel_compression_ratio = mel_compression_ratio

View File

@ -7,7 +7,7 @@ from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
from models.audio.tts.mini_encoder import AudioMiniEncoder from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model from trainer.networks import register_model
import bitsandbytes as bnb import torch_intermediary as ml
class CheckpointedLayer(nn.Module): class CheckpointedLayer(nn.Module):
@ -58,7 +58,7 @@ class Wav2VecMatcher(nn.Module):
self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1, self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
# nn.Embedding # nn.Embedding
self.text_embedding = bnb.nn.StableEmbedding(num_text_tokens, model_dim) self.text_embedding = ml.Embedding(num_text_tokens, model_dim)
self.encoder = CheckpointedXTransformer( self.encoder = CheckpointedXTransformer(
max_seq_len=-1, max_seq_len=-1,
use_pos_emb=False, use_pos_emb=False,
@ -75,8 +75,8 @@ class Wav2VecMatcher(nn.Module):
) )
self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim)) self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim))
self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim)) self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim))
self.w2v_query_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim) self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
self.w2v_value_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim) self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
self.decoder = CheckpointedXTransformer( self.decoder = CheckpointedXTransformer(
max_seq_len=-1, # Should be unused max_seq_len=-1, # Should be unused
use_pos_emb=False, use_pos_emb=False,

View File

@ -10,7 +10,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import bitsandbytes as bnb import torch_intermediary as ml
from trainer.networks import register_model from trainer.networks import register_model
@ -99,7 +99,7 @@ class ResNet(nn.Module):
self.conv4_x = self._make_layer(block, 128, num_block[2], 2) self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
self.conv5_x = self._make_layer(block, 256, num_block[3], 2) self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes) self.fc = ml.Linear(256 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride): def _make_layer(self, block, out_channels, num_blocks, stride):
"""make resnet layers(by layer i didnt mean this 'layer' was the """make resnet layers(by layer i didnt mean this 'layer' was the

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck from torchvision.models.resnet import BasicBlock, Bottleneck
import torchvision import torchvision
import bitsandbytes as bnb import torch_intermediary as ml
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@ -195,5 +195,5 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
def register_resnet50(opt_net, opt): def register_resnet50(opt_net, opt):
model = resnet50(pretrained=opt_net['pretrained']) model = resnet50(pretrained=opt_net['pretrained'])
if opt_net['custom_head_logits']: if opt_net['custom_head_logits']:
model.fc = bnb.nn.Linear8bitLt(512 * 4, opt_net['custom_head_logits']) model.fc = ml.Linear(512 * 4, opt_net['custom_head_logits'])
return model return model

View File

@ -11,7 +11,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
from trainer.networks import register_model from trainer.networks import register_model
@ -102,7 +102,7 @@ class ResNet(nn.Module):
self.conv4_x = self._make_layer(block, 128, num_block[2], 2) self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
self.conv5_x = self._make_layer(block, 256, num_block[3], 2) self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes) self.fc = ml.Linear(256 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride): def _make_layer(self, block, out_channels, num_blocks, stride):
"""make resnet layers(by layer i didnt mean this 'layer' was the """make resnet layers(by layer i didnt mean this 'layer' was the

View File

@ -11,7 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
from models.vqvae.scaled_weight_conv import ScaledWeightConv from models.vqvae.scaled_weight_conv import ScaledWeightConv
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint from utils.util import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
model_urls = { model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
@ -214,7 +214,7 @@ class ResNet(nn.Module):
self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2, self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2,
dilate=replace_stride_with_dilation[2]) dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes) self.fc = ml.Linear(512 * block.expansion, num_classes)
for m in self.modules(): for m in self.modules():
if isinstance(m, ScaledWeightConv): if isinstance(m, ScaledWeightConv):

View File

@ -3,7 +3,7 @@ import torch.nn as nn
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
import bitsandbytes as bnb import torch_intermediary as ml
class WideKernelVgg(nn.Module): class WideKernelVgg(nn.Module):
def __init__(self, nf=64, num_classes=2): def __init__(self, nf=64, num_classes=2):
@ -49,9 +49,9 @@ class WideKernelVgg(nn.Module):
nn.ReLU(), nn.ReLU(),
nn.MaxPool2d(kernel_size=2), nn.MaxPool2d(kernel_size=2),
nn.Flatten(), nn.Flatten(),
bnb.nn.Linear8bitLt(nf * 8 * 4 * 2, 100), ml.Linear(nf * 8 * 4 * 2, 100),
nn.ReLU(), nn.ReLU(),
bnb.nn.Linear8bitLt(100, num_classes) ml.Linear(100, num_classes)
) )
# These normalization constants should be derived experimentally. # These normalization constants should be derived experimentally.

View File

@ -10,7 +10,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get, checkpoint from utils.util import opt_get, checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
def exists(val): def exists(val):
@ -60,7 +60,7 @@ class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs) self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x): def forward(self, x):
y = self.emb(x) y = self.emb(x)
@ -101,9 +101,9 @@ class CLVP(nn.Module):
self.mask_conditioning_percentage = mask_conditioning_percentage self.mask_conditioning_percentage = mask_conditioning_percentage
# nn.Embedding # nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, model_dim) self.text_emb = ml.Embedding(num_text_tokens, model_dim)
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True) self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
self.to_text_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) self.to_text_latent = ml.Linear(latent_dim, latent_dim, bias=False)
self.distributed_collect = distributed_collect self.distributed_collect = distributed_collect
if mel_codes is None: if mel_codes is None:
@ -111,7 +111,7 @@ class CLVP(nn.Module):
else: else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self): def get_grad_norm_parameter_groups(self):
return { return {

View File

@ -9,7 +9,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get, checkpoint from utils.util import opt_get, checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
def exists(val): def exists(val):
@ -180,7 +180,7 @@ class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs) self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x): def forward(self, x):
y = self.emb(x) y = self.emb(x)
@ -205,8 +205,8 @@ class ContrastiveAudio(nn.Module):
self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2), self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent) self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent)
self.to_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) self.to_latent = ml.Linear(latent_dim, latent_dim, bias=False)
self.to_latent2 = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) self.to_latent2 = ml.Linear(latent_dim, latent_dim, bias=False)
self.to_latent2.weight.data = self.to_latent.weight.data self.to_latent2.weight.data = self.to_latent.weight.data
self.to_latent2.weight.DO_NOT_TRAIN = True self.to_latent2.weight.DO_NOT_TRAIN = True

View File

@ -10,7 +10,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get, checkpoint from utils.util import opt_get, checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
def exists(val): def exists(val):
@ -60,7 +60,7 @@ class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.emb = bnb.nn.StableEmbedding(*args, **kwargs) self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x): def forward(self, x):
y = self.emb(x) y = self.emb(x)
@ -88,14 +88,14 @@ class CVVP(nn.Module):
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2), self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage) self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
self.to_conditioning_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) self.to_conditioning_latent = ml.Linear(latent_dim, latent_dim, bias=False)
if mel_codes is None: if mel_codes is None:
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2) self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
else: else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False) self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self): def get_grad_norm_parameter_groups(self):
return { return {

View File

@ -7,7 +7,7 @@ from torch import einsum
from models.lucidrains.dalle.transformer import Transformer from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
import bitsandbytes as bnb import torch_intermediary as ml
def exists(val): def exists(val):
@ -47,19 +47,19 @@ class MelTextCLIP(nn.Module):
): ):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text) self.text_emb = ml.Embedding(num_text_tokens, dim_text)
# nn.Embedding # nn.Embedding
self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text) self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
heads=text_heads, rotary_emb=False) heads=text_heads, rotary_emb=False)
self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False) self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1) self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1)
# nn.Embedding # nn.Embedding
self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False) depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.)) self.temperature = nn.Parameter(torch.tensor(1.))
self.text_mask_percentage = text_mask_percentage self.text_mask_percentage = text_mask_percentage

View File

@ -7,7 +7,7 @@ from models.audio.tts.unified_voice2 import ConditioningEncoder
from models.lucidrains.dalle.transformer import Transformer from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
import bitsandbytes as bnb import torch_intermediary as ml
def exists(val): def exists(val):
@ -46,7 +46,7 @@ class VoiceCondCLIP(nn.Module):
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False) depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.)) self.temperature = nn.Parameter(torch.tensor(1.))
self.voice_mask_percentage = voice_mask_percentage self.voice_mask_percentage = voice_mask_percentage

View File

@ -11,7 +11,7 @@ from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from models.lucidrains.dalle.transformer import Transformer from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import opt_get from utils.util import opt_get
import bitsandbytes as bnb import torch_intermediary as ml
def exists(val): def exists(val):
@ -55,12 +55,12 @@ class VoiceCLIP(nn.Module):
): ):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text) self.text_emb = ml.Embedding(num_text_tokens, dim_text)
self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False) self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
# nn.Embedding # nn.Embedding
self.speech_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech)
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False) self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
if use_xformers: if use_xformers:
self.text_transformer = CheckpointedXTransformerEncoder( self.text_transformer = CheckpointedXTransformerEncoder(
@ -109,9 +109,9 @@ class VoiceCLIP(nn.Module):
self.distributed_collect = distributed_collect self.distributed_collect = distributed_collect
if not use_xformers: if not use_xformers:
# nn.Embedding # nn.Embedding
self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text) self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
# nn.Embedding # nn.Embedding
self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech) self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
def embed_text(self, text): def embed_text(self, text):
text_mask = torch.ones_like(text.float()).bool() text_mask = torch.ones_like(text.float()).bool()

View File

@ -6,7 +6,7 @@ import math
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import bitsandbytes as bnb import torch_intermediary as ml
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
@ -37,7 +37,7 @@ def linear(*args, **kwargs):
""" """
Create a linear module. Create a linear module.
""" """
return bnb.nn.Linear8bitLt(*args, **kwargs) return ml.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs): def avg_pool_nd(dims, *args, **kwargs):

View File

@ -6,7 +6,7 @@ from models.arch_util import ConvGnLelu, default_init_weights, make_layer
from models.diffusion.nn import timestep_embedding from models.diffusion.nn import timestep_embedding
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint from utils.util import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
# Conditionally uses torch's checkpoint functionality if it is enabled in the opt file. # Conditionally uses torch's checkpoint functionality if it is enabled in the opt file.
@ -29,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True) self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True)
self.emb_layers = nn.Sequential( self.emb_layers = nn.Sequential(
nn.SiLU(), nn.SiLU(),
bnb.nn.Linear8bitLt( ml.Linear(
mid_channels*4, mid_channels*4,
mid_channels, mid_channels,
), ),
@ -144,9 +144,9 @@ class RRDBNet(nn.Module):
# Guided diffusion uses a time embedding. # Guided diffusion uses a time embedding.
time_embed_dim = mid_channels * 4 time_embed_dim = mid_channels * 4
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
bnb.nn.Linear8bitLt(mid_channels, time_embed_dim), ml.Linear(mid_channels, time_embed_dim),
nn.SiLU(), nn.SiLU(),
bnb.nn.Linear8bitLt(time_embed_dim, time_embed_dim), ml.Linear(time_embed_dim, time_embed_dim),
) )
self.body = make_layer( self.body = make_layer(

View File

@ -20,7 +20,7 @@ from models.diffusion.nn import (
) )
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint from utils.util import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
class AttentionPool2d(nn.Module): class AttentionPool2d(nn.Module):
@ -517,7 +517,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None: if self.num_classes is not None:
# nn.Embedding # nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive. assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
@ -869,16 +869,16 @@ class EncoderUNetModel(nn.Module):
) )
elif pool == "spatial": elif pool == "spatial":
self.out = nn.Sequential( self.out = nn.Sequential(
bnb.nn.Linear8bitLt(self._feature_size, 2048), ml.Linear(self._feature_size, 2048),
nn.ReLU(), nn.ReLU(),
bnb.nn.Linear8bitLt(2048, self.out_channels), ml.Linear(2048, self.out_channels),
) )
elif pool == "spatial_v2": elif pool == "spatial_v2":
self.out = nn.Sequential( self.out = nn.Sequential(
bnb.nn.Linear8bitLt(self._feature_size, 2048), ml.Linear(self._feature_size, 2048),
normalization(2048), normalization(2048),
nn.SiLU(), nn.SiLU(),
bnb.nn.Linear8bitLt(2048, self.out_channels), ml.Linear(2048, self.out_channels),
) )
else: else:
raise NotImplementedError(f"Unexpected {pool} pooling") raise NotImplementedError(f"Unexpected {pool} pooling")

View File

@ -26,7 +26,7 @@ from models.diffusion.nn import (
) )
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint from utils.util import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
class AttentionPool2d(nn.Module): class AttentionPool2d(nn.Module):
@ -478,7 +478,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None: if self.num_classes is not None:
# nn.Embedding # nn.Embedding
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim) self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [
@ -738,7 +738,7 @@ class ResNetEncoder(nn.Module):
dilate=replace_stride_with_dilation[2]) dilate=replace_stride_with_dilation[2])
f=512 f=512
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = bnb.nn.Linear8bitLt(f * block.expansion, output_dim) self.fc = ml.Linear(f * block.expansion, output_dim)
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import checkpoint, opt_get from utils.util import checkpoint, opt_get
import bitsandbytes as bnb import torch_intermediary as ml
class Discriminator_VGG_128(nn.Module): class Discriminator_VGG_128(nn.Module):
@ -47,8 +47,8 @@ class Discriminator_VGG_128(nn.Module):
input_img_factor = input_img_factor // 2 input_img_factor = input_img_factor // 2
final_nf = nf * 16 final_nf = nf * 16
self.linear1 = bnb.nn.Linear8bitLt(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) self.linear1 = ml.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1) self.linear2 = ml.Linear(100, 1)
# activation function # activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
@ -130,8 +130,8 @@ class Discriminator_VGG_128_GN(nn.Module):
# activation function # activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) self.linear1 = ml.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1) self.linear2 = ml.Linear(100, 1)
def compute_body(self, x): def compute_body(self, x):
fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.conv0_0(x))
@ -220,8 +220,8 @@ class DiscriminatorVGG448GN(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
final_nf = nf * 8 final_nf = nf * 8
self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 7 * 7), 100) self.linear1 = ml.Linear(int(final_nf * 7 * 7), 100)
self.linear2 = bnb.nn.Linear8bitLt(100, 1) self.linear2 = ml.Linear(100, 1)
# Assign all new heads to the new param group.2 # Assign all new heads to the new param group.2
for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]: for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]:

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.init as init import torch.nn.init as init
import torch.nn.functional as F import torch.nn.functional as F
import bitsandbytes as bnb import torch_intermediary as ml
def initialize_weights(net_l, scale=1): def initialize_weights(net_l, scale=1):
@ -15,7 +15,7 @@ def initialize_weights(net_l, scale=1):
m.weight.data *= scale # for residual block m.weight.data *= scale # for residual block
if m.bias is not None: if m.bias is not None:
m.bias.data.zero_() m.bias.data.zero_()
elif isinstance(m, bnb.nn.Linear8bitLt): elif isinstance(m, ml.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in') init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale m.weight.data *= scale
if m.bias is not None: if m.bias is not None:

View File

@ -28,7 +28,7 @@ except:
APEX_AVAILABLE = False APEX_AVAILABLE = False
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
import bitsandbytes as bnb import torch_intermediary as ml
num_cores = multiprocessing.cpu_count() num_cores = multiprocessing.cpu_count()
@ -352,7 +352,7 @@ class RGBBlock(nn.Module):
def __init__(self, latent_dim, input_channel, upsample, rgba=False): def __init__(self, latent_dim, input_channel, upsample, rgba=False):
super().__init__() super().__init__()
self.input_channel = input_channel self.input_channel = input_channel
self.to_style = bnb.nn.Linear8bitLt(latent_dim, input_channel) self.to_style = ml.Linear(latent_dim, input_channel)
out_filters = 3 if not rgba else 4 out_filters = 3 if not rgba else 4
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False) self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
@ -490,16 +490,16 @@ class GeneratorBlockWithStructure(nn.Module):
# Uses stylegan1 style blocks for injecting structural latent. # Uses stylegan1 style blocks for injecting structural latent.
self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1) self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
self.to_noise0 = bnb.nn.Linear8bitLt(1, filters) self.to_noise0 = ml.Linear(1, filters)
self.noise0 = equal_lr(NoiseInjection(filters)) self.noise0 = equal_lr(NoiseInjection(filters))
self.adain0 = AdaptiveInstanceNorm(filters, latent_dim) self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, filters) self.to_style1 = ml.Linear(latent_dim, filters)
self.to_noise1 = bnb.nn.Linear8bitLt(1, filters) self.to_noise1 = ml.Linear(1, filters)
self.conv1 = Conv2DMod(filters, filters, 3) self.conv1 = Conv2DMod(filters, filters, 3)
self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters) self.to_style2 = ml.Linear(latent_dim, filters)
self.to_noise2 = bnb.nn.Linear8bitLt(1, filters) self.to_noise2 = ml.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3) self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu() self.activation = leaky_relu()
@ -541,12 +541,12 @@ class GeneratorBlock(nn.Module):
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1) self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
input_channels = input_channels * 2 input_channels = input_channels * 2
self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, input_channels) self.to_style1 = ml.Linear(latent_dim, input_channels)
self.to_noise1 = bnb.nn.Linear8bitLt(1, filters) self.to_noise1 = ml.Linear(1, filters)
self.conv1 = Conv2DMod(input_channels, filters, 3) self.conv1 = Conv2DMod(input_channels, filters, 3)
self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters) self.to_style2 = ml.Linear(latent_dim, filters)
self.to_noise2 = bnb.nn.Linear8bitLt(1, filters) self.to_noise2 = ml.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3) self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu() self.activation = leaky_relu()
@ -725,7 +725,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
def _init_weights(self): def _init_weights(self):
for m in self.modules(): for m in self.modules():
if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt} and hasattr(m, 'weight'): if type(m) in {nn.Conv2d, ml.Linear} and hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks: for block in self.gen.blocks:
@ -805,7 +805,7 @@ class StyleGan2Discriminator(nn.Module):
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1) self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.flatten = Flatten() self.flatten = Flatten()
self.to_logit = bnb.nn.Linear8bitLt(latent_dim, 1) self.to_logit = ml.Linear(latent_dim, 1)
self._init_weights() self._init_weights()
@ -837,7 +837,7 @@ class StyleGan2Discriminator(nn.Module):
def _init_weights(self): def _init_weights(self):
for m in self.modules(): for m in self.modules():
if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt}: if type(m) in {nn.Conv2d, ml.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

View File

@ -12,7 +12,7 @@ from torch import nn
from data.images.byol_attachment import RandomApply from data.images.byol_attachment import RandomApply
from trainer.networks import register_model, create_model from trainer.networks import register_model, create_model
from utils.util import checkpoint, opt_get from utils.util import checkpoint, opt_get
import bitsandbytes as bnb import torch_intermediary as ml
def default(val, def_val): def default(val, def_val):
@ -79,10 +79,10 @@ class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size=4096): def __init__(self, dim, projection_size, hidden_size=4096):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
bnb.nn.Linear8bitLt(dim, hidden_size), ml.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size), nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
bnb.nn.Linear8bitLt(hidden_size, projection_size) ml.Linear(hidden_size, projection_size)
) )
def forward(self, x): def forward(self, x):
@ -104,10 +104,10 @@ class StructuralMLP(nn.Module):
nn.BatchNorm2d(c), nn.BatchNorm2d(c),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Flatten(), nn.Flatten(),
bnb.nn.Linear8bitLt(flattened_dim, hidden_size), ml.Linear(flattened_dim, hidden_size),
nn.BatchNorm1d(hidden_size), nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
bnb.nn.Linear8bitLt(hidden_size, projection_size) ml.Linear(hidden_size, projection_size)
) )
def forward(self, x): def forward(self, x):

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import bitsandbytes as bnb import torch_intermediary as ml
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] __all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
@ -109,8 +109,8 @@ class FixupResNet(nn.Module):
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2) self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
self.bias2 = nn.Parameter(torch.zeros(1)) self.bias2 = nn.Parameter(torch.zeros(1))
reduced_img_sz = int(input_img_size / 32) reduced_img_sz = int(input_img_size / 32)
self.fc1 = bnb.nn.Linear8bitLt(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) self.fc1 = ml.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = bnb.nn.Linear8bitLt(100, num_classes) self.fc2 = ml.Linear(100, num_classes)
for m in self.modules(): for m in self.modules():
if isinstance(m, FixupBasicBlock): if isinstance(m, FixupBasicBlock):
@ -125,7 +125,7 @@ class FixupResNet(nn.Module):
if m.downsample is not None: if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
''' '''
elif isinstance(m, bnb.nn.Linear8bitLt): elif isinstance(m, ml.Linear):
nn.init.constant_(m.weight, 0) nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)''' nn.init.constant_(m.bias, 0)'''

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from models.arch_util import ResBlock from models.arch_util import ResBlock
from models.lucidrains.x_transformers import Encoder from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model from trainer.networks import register_model
import bitsandbytes as bnb import torch_intermediary as ml
class VitLatent(nn.Module): class VitLatent(nn.Module):
@ -32,10 +32,10 @@ class VitLatent(nn.Module):
do_checkpointing=True do_checkpointing=True
) )
self.mlp = nn.Sequential(bnb.nn.Linear8bitLt(hidden_dim, hidden_dim*2), self.mlp = nn.Sequential(ml.Linear(hidden_dim, hidden_dim*2),
nn.BatchNorm1d(hidden_dim*2), nn.BatchNorm1d(hidden_dim*2),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
bnb.nn.Linear8bitLt(hidden_dim*2, hidden_dim)) ml.Linear(hidden_dim*2, hidden_dim))
def provide_ema(self, ema): def provide_ema(self, ema):
self.ema = ema self.ema = ema

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from rotary_embedding_torch import apply_rotary_emb from rotary_embedding_torch import apply_rotary_emb
import bitsandbytes as bnb import torch_intermediary as ml
# helpers # helpers
@ -48,9 +48,9 @@ class Attention(nn.Module):
self.stable = stable self.stable = stable
self.causal = causal self.causal = causal
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
bnb.nn.Linear8bitLt(inner_dim, dim), ml.Linear(inner_dim, dim),
nn.Dropout(dropout) nn.Dropout(dropout)
) )
@ -103,10 +103,10 @@ class SparseConvCausalAttention(nn.Module):
self.stable = stable self.stable = stable
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
bnb.nn.Linear8bitLt(inner_dim, dim), ml.Linear(inner_dim, dim),
nn.Dropout(dropout) nn.Dropout(dropout)
) )
@ -223,10 +223,10 @@ class SparseAxialCausalAttention(nn.Module):
self.stable = stable self.stable = stable
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False) self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
bnb.nn.Linear8bitLt(inner_dim, dim), ml.Linear(inner_dim, dim),
nn.Dropout(dropout) nn.Dropout(dropout)
) )

View File

@ -11,7 +11,7 @@ from models.lucidrains.dalle.attention import Attention, SparseAttention, Sparse
from rotary_embedding_torch import RotaryEmbedding, broadcat from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock from g_mlp_pytorch import gMLPBlock
import bitsandbytes as bnb import torch_intermediary as ml
# helpers # helpers
@ -79,10 +79,10 @@ class FeedForward(nn.Module):
def __init__(self, dim, dropout = 0., mult = 4.): def __init__(self, dim, dropout = 0., mult = 4.):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
bnb.nn.Linear8bitLt(dim, dim * mult * 2), ml.Linear(dim, dim * mult * 2),
GEGLU(), GEGLU(),
nn.Dropout(dropout), nn.Dropout(dropout),
bnb.nn.Linear8bitLt(dim * mult, dim) ml.Linear(dim * mult, dim)
) )
def forward(self, x): def forward(self, x):

View File

@ -21,7 +21,7 @@ try:
APEX_AVAILABLE = True APEX_AVAILABLE = True
except: except:
APEX_AVAILABLE = False APEX_AVAILABLE = False
import bitsandbytes as bnb import torch_intermediary as ml
# helpers # helpers
@ -357,10 +357,10 @@ class FeedForward(nn.Module):
activation = default(activation, nn.GELU) activation = default(activation, nn.GELU)
self.glu = glu self.glu = glu
self.w1 = bnb.nn.Linear8bitLt(dim, dim * mult * (2 if glu else 1)) self.w1 = ml.Linear(dim, dim * mult * (2 if glu else 1))
self.act = activation() self.act = activation()
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.w2 = bnb.nn.Linear8bitLt(dim * mult, dim) self.w2 = ml.Linear(dim * mult, dim)
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
if not self.glu: if not self.glu:
@ -402,10 +402,10 @@ class Attention(nn.Module):
self.global_heads = heads - local_heads self.global_heads = heads - local_heads
self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
self.to_q = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) self.to_q = ml.Linear(dim, inner_dim, bias = qkv_bias)
self.to_k = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) self.to_k = ml.Linear(dim, inner_dim, bias = qkv_bias)
self.to_v = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias) self.to_v = ml.Linear(dim, inner_dim, bias = qkv_bias)
self.to_out = bnb.nn.Linear8bitLt(inner_dim, dim, bias = attn_out_bias) self.to_out = ml.Linear(inner_dim, dim, bias = attn_out_bias)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs): def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs):
@ -460,7 +460,7 @@ class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len): def __init__(self, dim, max_seq_len):
super().__init__() super().__init__()
# nn.Embedding # nn.Embedding
self.emb = bnb.nn.StableEmbedding(max_seq_len, dim) self.emb = ml.Embedding(max_seq_len, dim)
def forward(self, x): def forward(self, x):
t = torch.arange(x.shape[1], device=x.device) t = torch.arange(x.shape[1], device=x.device)
@ -622,7 +622,7 @@ class PerformerLM(nn.Module):
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
# nn.Embedding # nn.Embedding
self.token_emb = bnb.nn.StableEmbedding(num_tokens, dim) self.token_emb = ml.Embedding(num_tokens, dim)
if rotary_position_emb: if rotary_position_emb:
self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len) self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
@ -639,7 +639,7 @@ class PerformerLM(nn.Module):
self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens) self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens)
self.norm = nn.LayerNorm(dim) self.norm = nn.LayerNorm(dim)
self.to_out = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embed else None self.to_out = ml.Linear(dim, num_tokens) if not tie_embed else None
def check_redraw_projections(self): def check_redraw_projections(self):
self.performer.check_redraw_projections() self.performer.check_redraw_projections()

View File

@ -8,7 +8,7 @@ from torch.cuda.amp import autocast
from einops import rearrange, repeat from einops import rearrange, repeat
from contextlib import contextmanager from contextlib import contextmanager
import bitsandbytes as bnb import torch_intermediary as ml
def par(t, nm): def par(t, nm):
@ -356,9 +356,9 @@ class VectorQuantize(nn.Module):
codebook_dim = default(codebook_dim, dim) codebook_dim = default(codebook_dim, dim)
requires_projection = codebook_dim != dim requires_projection = codebook_dim != dim
self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection \ self.project_in = ml.Linear(dim, codebook_dim) if requires_projection \
else nn.Identity() else nn.Identity()
self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection \ self.project_out = ml.Linear(codebook_dim, dim) if requires_projection \
else nn.Identity() else nn.Identity()
self.eps = eps self.eps = eps

View File

@ -11,7 +11,7 @@ from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import bitsandbytes as bnb import torch_intermediary as ml
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64
@ -127,7 +127,7 @@ class AbsolutePositionalEmbedding(nn.Module):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim ** -0.5
# nn.Embedding # nn.Embedding
self.emb = bnb.nn.StableEmbedding(max_seq_len, dim) self.emb = ml.Embedding(max_seq_len, dim)
def forward(self, x): def forward(self, x):
n = torch.arange(x.shape[1], device=x.device) n = torch.arange(x.shape[1], device=x.device)
@ -157,7 +157,7 @@ class RelativePositionBias(nn.Module):
self.num_buckets = num_buckets self.num_buckets = num_buckets
self.max_distance = max_distance self.max_distance = max_distance
# nn.Embedding # nn.Embedding
self.relative_attention_bias = bnb.nn.StableEmbedding(num_buckets, heads) self.relative_attention_bias = ml.Embedding(num_buckets, heads)
@staticmethod @staticmethod
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
@ -363,7 +363,7 @@ class RMSScaleShiftNorm(nn.Module):
self.cdim = 1 self.cdim = 1
self.pdim = -1 self.pdim = -1
else: else:
self.scale_shift_process = bnb.nn.Linear8bitLt(embed_dim, dim * 2, bias=bias) self.scale_shift_process = ml.Linear(embed_dim, dim * 2, bias=bias)
self.cdim = -1 self.cdim = -1
self.pdim = 1 self.pdim = 1
@ -450,7 +450,7 @@ class GLU(nn.Module):
def __init__(self, dim_in, dim_out, activation): def __init__(self, dim_in, dim_out, activation):
super().__init__() super().__init__()
self.act = activation self.act = activation
self.proj = bnb.nn.Linear8bitLt(dim_in, dim_out * 2) self.proj = ml.Linear(dim_in, dim_out * 2)
def forward(self, x): def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1) x, gate = self.proj(x).chunk(2, dim=-1)
@ -475,7 +475,7 @@ class FeedForward(nn.Module):
activation = ReluSquared() if relu_squared else nn.GELU() activation = ReluSquared() if relu_squared else nn.GELU()
project_in = nn.Sequential( project_in = nn.Sequential(
bnb.nn.Linear8bitLt(dim, inner_dim), ml.Linear(dim, inner_dim),
activation activation
) if not glu else GLU(dim, inner_dim, activation) ) if not glu else GLU(dim, inner_dim, activation)
@ -483,7 +483,7 @@ class FeedForward(nn.Module):
project_in, project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout), nn.Dropout(dropout),
bnb.nn.Linear8bitLt(inner_dim, dim_out) ml.Linear(inner_dim, dim_out)
) )
# init last linear layer to 0 # init last linear layer to 0
@ -538,16 +538,16 @@ class Attention(nn.Module):
qk_dim = int(collab_compression * qk_dim) qk_dim = int(collab_compression * qk_dim)
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim)) self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
self.to_q = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False) self.to_q = ml.Linear(dim, qk_dim, bias=False)
self.to_k = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False) self.to_k = ml.Linear(dim, qk_dim, bias=False)
self.to_v = bnb.nn.Linear8bitLt(dim, v_dim, bias=False) self.to_v = ml.Linear(dim, v_dim, bias=False)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
# add GLU gating for aggregated values, from alphafold2 # add GLU gating for aggregated values, from alphafold2
self.to_v_gate = None self.to_v_gate = None
if gate_values: if gate_values:
self.to_v_gate = bnb.nn.Linear8bitLt(dim, v_dim) self.to_v_gate = ml.Linear(dim, v_dim)
nn.init.constant_(self.to_v_gate.weight, 0) nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 1) nn.init.constant_(self.to_v_gate.bias, 1)
@ -584,7 +584,7 @@ class Attention(nn.Module):
# attention on attention # attention on attention
self.attn_on_attn = on_attn self.attn_on_attn = on_attn
out_dim = default(out_dim, dim) out_dim = default(out_dim, dim)
self.to_out = nn.Sequential(bnb.nn.Linear8bitLt(v_dim, out_dim * 2), nn.GLU()) if on_attn else bnb.nn.Linear8bitLt(v_dim, out_dim) self.to_out = nn.Sequential(ml.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, out_dim)
self.rel_pos_bias = rel_pos_bias self.rel_pos_bias = rel_pos_bias
if rel_pos_bias: if rel_pos_bias:
@ -1080,7 +1080,7 @@ class ViTransformerWrapper(nn.Module):
self.patch_size = patch_size self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = bnb.nn.Linear8bitLt(patch_dim, dim) self.patch_to_embedding = ml.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout) self.dropout = nn.Dropout(emb_dropout)
@ -1139,18 +1139,18 @@ class TransformerWrapper(nn.Module):
self.shift_mem_down = shift_mem_down self.shift_mem_down = shift_mem_down
# nn.Embedding # nn.Embedding
self.token_emb = bnb.nn.StableEmbedding(num_tokens, emb_dim) self.token_emb = ml.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0) use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout) self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = bnb.nn.Linear8bitLt(emb_dim, dim) if emb_dim != dim else nn.Identity() self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim) self.norm = nn.LayerNorm(dim)
self.init_() self.init_()
self.to_logits = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper # memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0) num_memory_tokens = default(num_memory_tokens, 0)
@ -1237,12 +1237,12 @@ class ContinuousTransformerWrapper(nn.Module):
use_pos_emb and not attn_layers.has_pos_emb) else always(0) use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout) self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = bnb.nn.Linear8bitLt(dim_in, dim) if exists(dim_in) else nn.Identity() self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
self.attn_layers = attn_layers self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim) self.norm = nn.LayerNorm(dim)
self.project_out = bnb.nn.Linear8bitLt(dim, dim_out) if exists(dim_out) else nn.Identity() self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
def forward( def forward(
self, self,

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from torch import einsum from torch import einsum
from utils.weight_scheduler import LinearDecayWeightScheduler from utils.weight_scheduler import LinearDecayWeightScheduler
import bitsandbytes as bnb import torch_intermediary as ml
class GumbelQuantizer(nn.Module): class GumbelQuantizer(nn.Module):
@ -12,7 +12,7 @@ class GumbelQuantizer(nn.Module):
super().__init__() super().__init__()
self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1) self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
# nn.Embedding # nn.Embedding
self.codebook = bnb.nn.StableEmbedding(num_tokens, codebook_dim) self.codebook = ml.Embedding(num_tokens, codebook_dim)
self.straight_through = straight_through self.straight_through = straight_through
self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000) self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
self.step = 0 self.step = 0

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from models.arch_util import l2norm, sample_vectors, default, ema_inplace from models.arch_util import l2norm, sample_vectors, default, ema_inplace
import bitsandbytes as bnb import torch_intermediary as ml
def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
@ -185,8 +185,8 @@ class VectorQuantize(nn.Module):
codebook_dim = default(codebook_dim, dim) codebook_dim = default(codebook_dim, dim)
requires_projection = codebook_dim != dim requires_projection = codebook_dim != dim
self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection else nn.Identity() self.project_in = ml.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection else nn.Identity() self.project_out = ml.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
self.eps = eps self.eps = eps

View File

@ -0,0 +1,41 @@
"""Central switchboard for the layer and optimizer classes used by the project.

Call sites do ``import torch_intermediary as ml`` and refer to ``ml.Linear``,
``ml.Embedding``, ``ml.Adam`` and ``ml.AdamW`` instead of naming a ``torch`` or
``bitsandbytes`` class directly.  Each ``OVERRIDE_*`` flag below selects the
bitsandbytes implementation for one of those names; with the flag off, the
stock torch class is exported.

    name       flag on                                  flag off
    Linear     bitsandbytes.nn.Linear8bitLt             torch.nn.Linear
    Embedding  bitsandbytes.nn.StableEmbedding          torch.nn.Embedding
               (bitsandbytes.nn.Embedding when
               USE_STABLE_EMBEDDING is False)
    Adam       bitsandbytes.optim.adam.Adam8bit         torch.optim.adam.Adam
    AdamW      bitsandbytes.optim.adamw.AdamW8bit       torch.optim.adamw.AdamW

If a requested bitsandbytes class cannot be imported, a ``RuntimeWarning`` is
emitted and the torch class is exported instead, so that importing this module
(and therefore every model/trainer module that depends on it) does not fail
just because bitsandbytes is unavailable.  The flags keep the value written
here in that case: they record what was requested, not what was resolved.
"""
import warnings

# Requested implementations. True asks for the bitsandbytes class.
OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = False
OVERRIDE_ADAM = True
OVERRIDE_ADAMW = True
# Only consulted when OVERRIDE_EMBEDDING is True.
USE_STABLE_EMBEDDING = True


def _warn_fallback(requested, fallback, err):
    """Warn that `requested` failed to import and `fallback` is used instead."""
    warnings.warn(
        '{} could not be imported ({}); falling back to {}.'.format(requested, err, fallback),
        RuntimeWarning,
        stacklevel=2,
    )


if OVERRIDE_LINEAR:
    try:
        from bitsandbytes.nn import Linear8bitLt as Linear
    except ImportError as err:
        _warn_fallback('bitsandbytes.nn.Linear8bitLt', 'torch.nn.Linear', err)
        from torch.nn import Linear
else:
    from torch.nn import Linear

if OVERRIDE_EMBEDDING:
    try:
        if USE_STABLE_EMBEDDING:
            from bitsandbytes.nn import StableEmbedding as Embedding
        else:
            from bitsandbytes.nn import Embedding
    except ImportError as err:
        _warn_fallback('bitsandbytes.nn embedding', 'torch.nn.Embedding', err)
        from torch.nn import Embedding
else:
    from torch.nn import Embedding

if OVERRIDE_ADAM:
    try:
        from bitsandbytes.optim.adam import Adam8bit as Adam
    except ImportError as err:
        _warn_fallback('bitsandbytes.optim.adam.Adam8bit', 'torch.optim.adam.Adam', err)
        from torch.optim.adam import Adam
else:
    from torch.optim.adam import Adam

if OVERRIDE_ADAMW:
    try:
        from bitsandbytes.optim.adamw import AdamW8bit as AdamW
    except ImportError as err:
        _warn_fallback('bitsandbytes.optim.adamw.AdamW8bit', 'torch.optim.adamw.AdamW', err)
        from torch.optim.adamw import AdamW
else:
    from torch.optim.adamw import AdamW

View File

@ -21,7 +21,7 @@ import torchvision.utils as utils
from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
from utils.util import opt_get, denormalize from utils.util import opt_get, denormalize
import bitsandbytes as bnb import torch_intermediary as ml
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -338,7 +338,7 @@ class ExtensibleTrainer(BaseModel):
for net in self.networks.values(): for net in self.networks.values():
for mod in net.modules(): for mod in net.modules():
fan_in = -1 fan_in = -1
if isinstance(mod, bnb.nn.Linear8bitLt): if isinstance(mod, ml.Linear):
fan_in = mod.weight.data.shape[1] fan_in = mod.weight.data.shape[1]
elif isinstance(mod, nn.Conv1d): elif isinstance(mod, nn.Conv1d):
fan_in = mod.weight.data.shape[0] fan_in = mod.weight.data.shape[0]

View File

@ -7,7 +7,7 @@ import torch.nn as nn
import trainer.networks as networks import trainer.networks as networks
import trainer.lr_scheduler as lr_scheduler import trainer.lr_scheduler as lr_scheduler
from .base_model import BaseModel from .base_model import BaseModel
import bitsandbytes as bnb import torch_intermediary as ml
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -43,7 +43,7 @@ class FeatureModel(BaseModel):
if self.rank <= 0: if self.rank <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k)) logger.warning('Params [{:s}] will not optimize.'.format(k))
# torch.optim.Adam # torch.optim.Adam
self.optimizer_G = bnb.optim.Adam8bit(optim_params, lr=train_opt['lr_G'], self.optimizer_G = ml.Adam(optim_params, lr=train_opt['lr_G'],
weight_decay=wd_G, weight_decay=wd_G,
betas=(train_opt['beta1_G'], train_opt['beta2_G'])) betas=(train_opt['beta1_G'], train_opt['beta2_G']))
self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_G)

View File

@ -3,7 +3,7 @@ from collections import Counter
from collections import defaultdict from collections import defaultdict
import torch import torch
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
import bitsandbytes as bnb import torch_intermediary as ml
from utils.util import opt_get from utils.util import opt_get
@ -137,7 +137,7 @@ class CosineAnnealingLR_Restart(_LRScheduler):
if __name__ == "__main__": if __name__ == "__main__":
#torch.optim.Adam #torch.optim.Adam
optimizer = bnb.optim.Adam8bit([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0, optimizer = ml.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
betas=(0.9, 0.99)) betas=(0.9, 0.99))
############################## ##############################
# MultiStepLR_Restart # MultiStepLR_Restart

View File

@ -12,7 +12,7 @@ from utils.util import recursively_detach, opt_get, clip_grad_norm
logger = logging.getLogger('base') logger = logging.getLogger('base')
import bitsandbytes as bnb import torch_intermediary as ml
# Defines the expected API for a single training step # Defines the expected API for a single training step
class ConfigurableStep(Module): class ConfigurableStep(Module):
@ -84,7 +84,7 @@ class ConfigurableStep(Module):
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d, norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm) nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
# nn.Embedding # nn.Embedding
emb_modules = (bnb.nn.StableEmbedding, nn.EmbeddingBag) emb_modules = (ml.Embedding, nn.EmbeddingBag)
param_names_notweights = set() param_names_notweights = set()
all_param_names = set() all_param_names = set()
param_map = {} param_map = {}
@ -126,7 +126,7 @@ class ConfigurableStep(Module):
{ 'params': params_notweights, 'weight_decay': 0 } { 'params': params_notweights, 'weight_decay': 0 }
] ]
# torch.optim.AdamW # torch.optim.AdamW
opt = bnb.optim.AdamW8bit(groups, lr=opt_config['lr'], opt = ml.AdamW(groups, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2), weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt._group_names = [params_names_weights, params_names_notweights] opt._group_names = [params_names_weights, params_names_notweights]
@ -145,7 +145,7 @@ class ConfigurableStep(Module):
# parameters and just use a normal AdamW implementation. In a large network, these weights will normally # parameters and just use a normal AdamW implementation. In a large network, these weights will normally
# be a tiny fraction of the total weights. # be a tiny fraction of the total weights.
# torch.optim.AdamW # torch.optim.AdamW
opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0, opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt_unweighted._config = opt_config opt_unweighted._config = opt_config
opt_unweighted._config['network'] = net_name opt_unweighted._config['network'] = net_name
@ -153,7 +153,7 @@ class ConfigurableStep(Module):
self.optimizers.append(opt_unweighted) self.optimizers.append(opt_unweighted)
# torch.optim.AdamW # torch.optim.AdamW
opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=bnb.optim.AdamW8bit, lr=opt_config['lr'], opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=ml.AdamW, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2), weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt.param_groups[0]['initial_lr'] = opt_config['lr'] opt.param_groups[0]['initial_lr'] = opt_config['lr']
@ -168,7 +168,7 @@ class ConfigurableStep(Module):
elif self.step_opt['optimizer'] == 'lamb': elif self.step_opt['optimizer'] == 'lamb':
from trainer.optimizers.lamb import Lamb from trainer.optimizers.lamb import Lamb
# torch.optim.AdamW # torch.optim.AdamW
opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0, opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999))) betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
opt_unweighted._config = opt_config opt_unweighted._config = opt_config
opt_unweighted._config['network'] = net_name opt_unweighted._config['network'] = net_name