I sucked off the hyptothetical wizard again, just using BNB's ADAM optimizer nets HUGE savings, but I don't know the output costs, will need to test
This commit is contained in:
parent
01c0941a40
commit
6676c89c0e
|
@ -9,7 +9,7 @@ import torch.nn.utils.spectral_norm as SpectralNorm
|
|||
from math import sqrt
|
||||
|
||||
from utils.util import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -74,7 +74,7 @@ def initialize_weights(net_l, scale=1):
|
|||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, bnb.nn.Linear8bitLt):
|
||||
elif isinstance(m, ml.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
|
@ -109,7 +109,7 @@ def default_init_weights(module, scale=1):
|
|||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, a=0, mode='fan_in', bias=0)
|
||||
m.weight.data *= scale
|
||||
elif isinstance(m, bnb.nn.Linear8bitLt):
|
||||
elif isinstance(m, ml.Linear):
|
||||
kaiming_init(m, a=0, mode='fan_in', bias=0)
|
||||
m.weight.data *= scale
|
||||
|
||||
|
@ -142,7 +142,7 @@ def linear(*args, **kwargs):
|
|||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return bnb.nn.Linear8bitLt(*args, **kwargs)
|
||||
return ml.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
|
|
|
@ -9,7 +9,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
|
|||
from models.audio.tts.tacotron2.text import sequence_to_text
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def only_letters(string):
|
||||
|
@ -52,7 +52,7 @@ class Wav2VecWrapper(nn.Module):
|
|||
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
|
||||
# Perform some surgery to get the model we actually want.
|
||||
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
|
||||
self.w2v.lm_head = bnb.nn.Linear8bitLt(self.w2v.config.hidden_size, vocab_size)
|
||||
self.w2v.lm_head = ml.Linear(self.w2v.config.hidden_size, vocab_size)
|
||||
self.w2v.config.vocab_size = vocab_size
|
||||
self.w2v.config.pad_token_id = 0
|
||||
self.w2v.config.ctc_loss_reduction = 'sum'
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
from typing import Type, Any, Callable, Union, List, Optional
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
|
@ -173,7 +173,7 @@ class ResNet(nn.Module):
|
|||
self.layer4 = self._make_layer(block, 512, layers[3], stride=4,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
||||
self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes)
|
||||
self.fc = ml.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
|
|
|
@ -15,14 +15,14 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|||
from models.arch_util import ResBlock
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class Mel2Vec2FeatureProjection(nn.Module):
|
||||
def __init__(self, inner_dim, dropout):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5)
|
||||
self.projection = bnb.nn.Linear8bitLt(inner_dim, inner_dim)
|
||||
self.projection = ml.Linear(inner_dim, inner_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
@ -59,10 +59,10 @@ class Wav2Vec2Attention(nn.Module):
|
|||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
self.k_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = bnb.nn.Linear8bitLt(embed_dim, embed_dim, bias=bias)
|
||||
self.k_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
@ -183,10 +183,10 @@ class Wav2Vec2FeedForward(nn.Module):
|
|||
super().__init__()
|
||||
self.intermediate_dropout = nn.Dropout(dropout)
|
||||
|
||||
self.intermediate_dense = bnb.nn.Linear8bitLt(hidden_size, intermediate_size)
|
||||
self.intermediate_dense = ml.Linear(hidden_size, intermediate_size)
|
||||
self.intermediate_act_fn = F.gelu
|
||||
|
||||
self.output_dense = bnb.nn.Linear8bitLt(intermediate_size, hidden_size)
|
||||
self.output_dense = ml.Linear(intermediate_size, hidden_size)
|
||||
self.output_dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
@ -430,7 +430,7 @@ class Mel2Vec(nn.Module):
|
|||
k = math.sqrt(1 / module.projection.in_features)
|
||||
nn.init.uniform_(module.projection.weight, a=-k, b=k)
|
||||
nn.init.uniform_(module.projection.bias, a=-k, b=k)
|
||||
elif isinstance(module, bnb.nn.Linear8bitLt):
|
||||
elif isinstance(module, ml.Linear):
|
||||
if self.disable_custom_linear_init:
|
||||
return
|
||||
module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
|
||||
|
@ -511,7 +511,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
self.codevectors = nn.Parameter(
|
||||
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
|
||||
)
|
||||
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
|
||||
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
|
||||
|
||||
# can be decayed for training
|
||||
self.temperature = 2
|
||||
|
@ -607,8 +607,8 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
self.inp_length_factor = inp_length_multiplier
|
||||
|
||||
# make sure that project_hid & project_q are initialized like normal linear layers
|
||||
self.project_hid = bnb.nn.Linear8bitLt(inner_dim, self.quantizer.codevector_dim)
|
||||
self.project_q = bnb.nn.Linear8bitLt(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
|
||||
self.project_hid = ml.Linear(inner_dim, self.quantizer.codevector_dim)
|
||||
self.project_q = ml.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
|
||||
|
||||
self.reconstruction = do_reconstruction_loss
|
||||
if do_reconstruction_loss:
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import AttentionBlock, ResBlock
|
||||
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
|
@ -57,8 +57,8 @@ class ConditioningAR(nn.Module):
|
|||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
# nn.Embedding
|
||||
self.embeddings = bnb.nn.StableEmbedding(num_vectors, dim)
|
||||
self.head = bnb.nn.Linear8bitLt(dim, num_vectors)
|
||||
self.embeddings = ml.Embedding(num_vectors, dim)
|
||||
self.head = ml.Linear(dim, num_vectors)
|
||||
|
||||
def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
|
||||
unused_params = []
|
||||
|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from math import sqrt
|
||||
|
||||
|
@ -25,7 +25,7 @@ from torch.utils.checkpoint import checkpoint
|
|||
|
||||
from trainer.networks import register_model
|
||||
|
||||
Linear = bnb.nn.Linear8bitLt
|
||||
Linear = ml.Linear
|
||||
ConvTranspose2d = nn.ConvTranspose2d
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import ResBlock
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
|
@ -24,7 +24,7 @@ class MultiGroupEmbedding(nn.Module):
|
|||
def __init__(self, tokens, groups, dim):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
|
||||
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
|
||||
def forward(self, x):
|
||||
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||
|
@ -161,7 +161,7 @@ class FlatDiffusion(nn.Module):
|
|||
# transformer network.
|
||||
if in_groups is None:
|
||||
# nn.Embedding
|
||||
self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels)
|
||||
self.embeddings = ml.Embedding(token_count, model_channels)
|
||||
else:
|
||||
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import AttentionBlock, ResBlock
|
||||
from models.audio.music.music_quantizer import MusicQuantizer
|
||||
|
@ -138,8 +138,8 @@ class GptMusicLower(nn.Module):
|
|||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
# nn.Embedding
|
||||
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
|
||||
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)])
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
|
||||
self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
|
||||
|
||||
def forward(self, mel, conditioning, return_latent=False):
|
||||
unused_params = []
|
||||
|
@ -241,8 +241,8 @@ class GptMusicUpper(nn.Module):
|
|||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
# nn.Embedding
|
||||
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
|
||||
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_upper_vectors) for _ in range(num_upper_groups)])
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
|
||||
self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
|
||||
|
||||
|
||||
def forward(self, mel, conditioning, return_latent=False):
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import AttentionBlock, ResBlock
|
||||
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
|
@ -75,8 +75,8 @@ class GptMusicLower(nn.Module):
|
|||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
# nn.Embedding
|
||||
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
|
||||
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_target_vectors) for _ in range(num_vaes)])
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
|
||||
self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
|
||||
|
||||
def forward(self, mel, return_latent=False):
|
||||
unused_params = []
|
||||
|
|
|
@ -3,7 +3,7 @@ import functools
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding
|
||||
from models.lucidrains.vq import VectorQuantize
|
||||
|
@ -22,8 +22,8 @@ class SelfClassifyingHead(nn.Module):
|
|||
use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
|
||||
self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2,
|
||||
sample_codebook_temp=init_temperature)
|
||||
self.to_output = bnb.nn.Linear8bitLt(dim, out_dim)
|
||||
self.to_decoder = bnb.nn.Linear8bitLt(out_dim, dim)
|
||||
self.to_output = ml.Linear(dim, out_dim)
|
||||
self.to_decoder = ml.Linear(out_dim, dim)
|
||||
|
||||
def do_ar_step(self, x, used_codes):
|
||||
h = self.dec(x)
|
||||
|
@ -91,7 +91,7 @@ class InstrumentQuantizer(nn.Module):
|
|||
"""
|
||||
super().__init__()
|
||||
self.op_dim = op_dim
|
||||
self.proj = bnb.nn.Linear8bitLt(op_dim, dim)
|
||||
self.proj = ml.Linear(op_dim, dim)
|
||||
self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
|
||||
self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
|
||||
self.min_gumbel_temperature = min_temp
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
|
||||
from trainer.networks import register_model
|
||||
|
@ -19,8 +19,8 @@ class Mel2VecCodesGpt(nn.Module):
|
|||
self.gpt = GPT2Model(self.config)
|
||||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
# nn.Embedding
|
||||
self.embeddings = nn.ModuleList([bnb.nn.StableEmbedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
|
||||
self.heads = nn.ModuleList([bnb.nn.Linear8bitLt(dim, num_vectors) for _ in range(num_groups)])
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
|
||||
self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)])
|
||||
|
||||
def forward(self, codes):
|
||||
assert codes.shape[-1] == self.num_groups
|
||||
|
|
|
@ -3,7 +3,7 @@ import functools
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import zero_module
|
||||
from models.vqvae.vqvae import Quantize
|
||||
|
@ -76,7 +76,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
self.codevectors = nn.Parameter(
|
||||
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
|
||||
)
|
||||
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
|
||||
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
|
||||
|
||||
# can be decayed for training
|
||||
self.temperature = 2
|
||||
|
|
|
@ -3,7 +3,7 @@ import functools
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import zero_module
|
||||
from models.vqvae.vqvae import Quantize
|
||||
|
@ -88,7 +88,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
self.codevectors = nn.Parameter(
|
||||
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
|
||||
)
|
||||
self.weight_proj = bnb.nn.Linear8bitLt(proj_dim, self.num_groups * self.num_vars)
|
||||
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
|
||||
|
||||
# can be decayed for training
|
||||
self.temperature = 2
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import torchvision
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
|
@ -56,12 +56,12 @@ class ConcatAttentionBlock(TimestepBlock):
|
|||
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
|
||||
if cond_projection:
|
||||
self.tdim = trunk_dim+cond_dim_hidden
|
||||
self.cond_project = bnb.nn.Linear8bitLt(cond_dim_in, cond_dim_hidden)
|
||||
self.cond_project = ml.Linear(cond_dim_in, cond_dim_hidden)
|
||||
else:
|
||||
self.tdim = trunk_dim
|
||||
self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv)
|
||||
self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv)
|
||||
self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False)
|
||||
self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
|
||||
self.out.weight.data.zero_()
|
||||
|
||||
def forward(self, x, cond, timestep_emb, rotary_emb):
|
||||
|
@ -89,7 +89,7 @@ class ConditioningEncoder(nn.Module):
|
|||
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
|
||||
self.time_proj = time_proj
|
||||
if time_proj:
|
||||
self.time_proj = bnb.nn.Linear8bitLt(time_embed_dim, embedding_dim)
|
||||
self.time_proj = ml.Linear(time_embed_dim, embedding_dim)
|
||||
self.attn = Encoder(
|
||||
dim=embedding_dim,
|
||||
depth=attn_blocks,
|
||||
|
|
|
@ -4,7 +4,7 @@ from time import time
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import ResBlock
|
||||
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
|
||||
|
@ -29,7 +29,7 @@ class MultiGroupEmbedding(nn.Module):
|
|||
def __init__(self, tokens, groups, dim):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
|
||||
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
|
||||
def forward(self, x):
|
||||
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||
|
@ -70,7 +70,7 @@ class ConcatAttentionBlock(TimestepBlock):
|
|||
self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
|
||||
self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
|
||||
self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
|
||||
self.out = bnb.nn.Linear8bitLt(contraction_dim*4, trunk_dim, bias=False)
|
||||
self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
|
||||
self.out.weight.data.zero_()
|
||||
|
||||
def forward(self, x, timestep_emb, rotary_emb):
|
||||
|
@ -131,7 +131,7 @@ class TransformerDiffusion(nn.Module):
|
|||
)
|
||||
|
||||
prenet_heads = prenet_channels//64
|
||||
self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, prenet_channels)
|
||||
self.input_converter = ml.Linear(input_vec_dim, prenet_channels)
|
||||
self.code_converter = Encoder(
|
||||
dim=prenet_channels,
|
||||
depth=prenet_layers,
|
||||
|
@ -147,7 +147,7 @@ class TransformerDiffusion(nn.Module):
|
|||
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
|
||||
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
|
||||
self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels)
|
||||
self.intg = ml.Linear(prenet_channels*2, model_channels)
|
||||
self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
|
|
|
@ -5,7 +5,7 @@ from random import randrange
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
|
||||
RelativeQKBias
|
||||
|
@ -71,13 +71,13 @@ class ConditioningEncoder(nn.Module):
|
|||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
|
||||
# nn.Embedding
|
||||
self.resolution_embedding = bnb.nn.StableEmbedding(num_resolutions, hidden_dim)
|
||||
self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim)
|
||||
self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.out = bnb.nn.Linear8bitLt(hidden_dim, out_dim, bias=False)
|
||||
self.out = ml.Linear(hidden_dim, out_dim, bias=False)
|
||||
self.dim = hidden_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
||||
|
@ -134,7 +134,7 @@ class TransformerDiffusion(nn.Module):
|
|||
linear(time_embed_dim, time_proj_dim),
|
||||
)
|
||||
# nn.Embedding
|
||||
self.resolution_embed = bnb.nn.StableEmbedding(resolution_steps, time_proj_dim)
|
||||
self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim)
|
||||
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision # For debugging, not actually used.
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.audio.music.gpt_music import GptMusicLower
|
||||
from models.audio.music.music_quantizer import MusicQuantizer
|
||||
|
@ -491,7 +491,7 @@ class UNetMusicModel(nn.Module):
|
|||
)
|
||||
|
||||
if self.ar_prior:
|
||||
self.ar_input = bnb.nn.Linear8bitLt(input_vec_dim, model_channels)
|
||||
self.ar_input = ml.Linear(input_vec_dim, model_channels)
|
||||
self.ar_prior_intg = Encoder(
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
|
@ -505,7 +505,7 @@ class UNetMusicModel(nn.Module):
|
|||
ff_mult=1,
|
||||
)
|
||||
else:
|
||||
self.input_converter = bnb.nn.Linear8bitLt(input_vec_dim, model_channels)
|
||||
self.input_converter = ml.Linear(input_vec_dim, model_channels)
|
||||
self.code_converter = Encoder(
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
|
@ -523,7 +523,7 @@ class UNetMusicModel(nn.Module):
|
|||
|
||||
if self.num_classes is not None:
|
||||
# nn.Embedding
|
||||
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
|
||||
self.label_emb = ml.Embedding(num_classes, time_embed_dim)
|
||||
self.use_raw_y_as_embedding = use_raw_y_as_embedding
|
||||
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from random import random
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
|
||||
from models.lucidrains.x_transformers import Encoder
|
||||
|
@ -38,11 +38,11 @@ class CtcCodeGenerator(nn.Module):
|
|||
pred_codes = (max_pad+1)*(max_repeat+1)
|
||||
|
||||
# nn.Embedding
|
||||
self.position_embedding = bnb.nn.StableEmbedding(max_length, model_dim)
|
||||
self.position_embedding = ml.Embedding(max_length, model_dim)
|
||||
# nn.Embedding
|
||||
self.codes_embedding = bnb.nn.StableEmbedding(ctc_codes, model_dim)
|
||||
self.codes_embedding = ml.Embedding(ctc_codes, model_dim)
|
||||
# nn.Embedding
|
||||
self.recursive_embedding = bnb.nn.StableEmbedding(pred_codes, model_dim)
|
||||
self.recursive_embedding = ml.Embedding(pred_codes, model_dim)
|
||||
self.mask_embedding = nn.Parameter(torch.randn(model_dim))
|
||||
self.encoder = Encoder(
|
||||
dim=model_dim,
|
||||
|
@ -54,8 +54,8 @@ class CtcCodeGenerator(nn.Module):
|
|||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
self.pred_head = bnb.nn.Linear8bitLt(model_dim, pred_codes)
|
||||
self.confidence_head = bnb.nn.Linear8bitLt(model_dim, 1)
|
||||
self.pred_head = ml.Linear(model_dim, pred_codes)
|
||||
self.confidence_head = ml.Linear(model_dim, 1)
|
||||
|
||||
def inference(self, codes, pads, repeats):
|
||||
position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))
|
||||
|
|
|
@ -5,7 +5,7 @@ from functools import partial
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
|
||||
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
|
||||
|
@ -18,7 +18,7 @@ class TimeIntegrationBlock(nn.Module):
|
|||
super().__init__()
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
bnb.nn.Linear8bitLt(
|
||||
ml.Linear(
|
||||
time_emb_dim,
|
||||
2 * dim
|
||||
),
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
from models.diffusion.nn import normalization, conv_nd, zero_module
|
||||
|
@ -139,7 +139,7 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
|
|||
def __init__(self, classes, distribute_zero_label=True, **kwargs):
|
||||
super().__init__()
|
||||
self.enc = AudioMiniEncoder(**kwargs)
|
||||
self.head = bnb.nn.Linear8bitLt(self.enc.dim, classes)
|
||||
self.head = ml.Linear(self.enc.dim, classes)
|
||||
self.num_classes = classes
|
||||
self.distribute_zero_label = distribute_zero_label
|
||||
|
||||
|
@ -184,7 +184,7 @@ class QueryProvidedAttentionBlock(nn.Module):
|
|||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.norm = normalization(channels)
|
||||
self.q = bnb.nn.Linear8bitLt(channels, channels)
|
||||
self.q = ml.Linear(channels, channels)
|
||||
self.qnorm = nn.LayerNorm(channels)
|
||||
self.kv = conv_nd(1, channels, channels*2, 1)
|
||||
if use_new_attention_order:
|
||||
|
|
|
@ -3,7 +3,7 @@ import math
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
@ -45,7 +45,7 @@ class RandomLatentConverter(nn.Module):
|
|||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
|
||||
bnb.nn.Linear8bitLt(channels, channels))
|
||||
ml.Linear(channels, channels))
|
||||
self.channels = channels
|
||||
|
||||
def forward(self, ref):
|
||||
|
|
|
@ -3,13 +3,13 @@ from librosa.filters import mel as librosa_mel_fn
|
|||
from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
|
||||
from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
|
||||
from models.audio.tts.tacotron2.stft import STFT
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class LinearNorm(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||
super(LinearNorm, self).__init__()
|
||||
self.linear_layer = torch.bnb.nn.Linear8bitLt(in_dim, out_dim, bias=bias)
|
||||
self.linear_layer = torch.ml.Linear(in_dim, out_dim, bias=bias)
|
||||
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.linear_layer.weight,
|
||||
|
|
|
@ -8,7 +8,7 @@ from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
|
|||
from models.audio.tts.tacotron2.hparams import create_hparams
|
||||
from trainer.networks import register_model
|
||||
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class LocationLayer(nn.Module):
|
||||
|
@ -465,7 +465,7 @@ class Tacotron2(nn.Module):
|
|||
self.n_mel_channels = hparams.n_mel_channels
|
||||
self.n_frames_per_step = hparams.n_frames_per_step
|
||||
# nn.Embedding
|
||||
self.embedding = bnb.nn.StableEmbedding(
|
||||
self.embedding = ml.Embedding(
|
||||
hparams.n_symbols, hparams.symbols_embedding_dim)
|
||||
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
|
||||
val = sqrt(3.0) * std # uniform bounds for std
|
||||
|
|
|
@ -13,7 +13,7 @@ from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
|
|||
from trainer.networks import register_model
|
||||
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from utils.util import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
|
||||
|
@ -187,7 +187,7 @@ class WaveTacotron2(nn.Module):
|
|||
self.n_mel_channels = hparams.n_mel_channels
|
||||
self.n_frames_per_step = hparams.n_frames_per_step
|
||||
# nn.Embedding
|
||||
self.embedding = bnb.nn.StableEmbedding(
|
||||
self.embedding = ml.Embedding(
|
||||
hparams.n_symbols, hparams.symbols_embedding_dim)
|
||||
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
|
||||
val = sqrt(3.0) * std # uniform bounds for std
|
||||
|
|
|
@ -25,7 +25,7 @@ import random
|
|||
from time import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
@ -37,7 +37,7 @@ class LearnedPositionEmbeddings(nn.Module):
|
|||
def __init__(self, seq_len, model_dim, init=.02, relative=False):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.emb = bnb.nn.StableEmbedding(seq_len, model_dim)
|
||||
self.emb = ml.Embedding(seq_len, model_dim)
|
||||
# Initializing this way is standard for GPT-2
|
||||
self.emb.weight.data.normal_(mean=0.0, std=init)
|
||||
self.relative = relative
|
||||
|
|
|
@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlo
|
|||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
|
@ -21,7 +21,7 @@ class MultiGroupEmbedding(nn.Module):
|
|||
def __init__(self, tokens, groups, dim):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
|
||||
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
|
||||
def forward(self, x):
|
||||
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||
|
@ -102,9 +102,9 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, model_channels)
|
||||
self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels)
|
||||
# nn.Embedding
|
||||
self.type_embedding = bnb.nn.StableEmbedding(types, model_channels)
|
||||
self.type_embedding = ml.Embedding(types, model_channels)
|
||||
|
||||
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
||||
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
||||
|
@ -112,7 +112,7 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
# transformer network.
|
||||
if in_groups is None:
|
||||
# nn.Embedding
|
||||
self.embeddings = bnb.nn.StableEmbedding(token_count, model_channels)
|
||||
self.embeddings = ml.Embedding(token_count, model_channels)
|
||||
else:
|
||||
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
|
@ -144,7 +144,7 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
|
||||
self.intg = bnb.nn.Linear8bitLt(model_channels*2, model_channels)
|
||||
self.intg = ml.Linear(model_channels*2, model_channels)
|
||||
self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
||||
|
@ -21,7 +21,7 @@ class MultiGroupEmbedding(nn.Module):
|
|||
def __init__(self, tokens, groups, dim):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.m = nn.ModuleList([bnb.nn.StableEmbedding(tokens, dim // groups) for _ in range(groups)])
|
||||
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
|
||||
def forward(self, x):
|
||||
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||
|
@ -42,7 +42,7 @@ class DietAttentionBlock(TimestepBlock):
|
|||
def __init__(self, in_dim, dim, heads, dropout):
|
||||
super().__init__()
|
||||
self.rms_scale_norm = RMSScaleShiftNorm(in_dim)
|
||||
self.proj = bnb.nn.Linear8bitLt(in_dim, dim)
|
||||
self.proj = ml.Linear(in_dim, dim)
|
||||
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout)
|
||||
self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
|
||||
|
||||
|
@ -107,9 +107,9 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
self.clvp_encoder = bnb.nn.Linear8bitLt(clvp_in_dim, prenet_channels)
|
||||
self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels)
|
||||
# nn.Embedding
|
||||
self.type_embedding = bnb.nn.StableEmbedding(types, prenet_channels)
|
||||
self.type_embedding = ml.Embedding(types, prenet_channels)
|
||||
|
||||
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
||||
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
||||
|
@ -117,7 +117,7 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
# transformer network.
|
||||
if in_groups is None:
|
||||
# nn.Embedding
|
||||
self.embeddings = bnb.nn.StableEmbedding(token_count, prenet_channels)
|
||||
self.embeddings = ml.Embedding(token_count, prenet_channels)
|
||||
else:
|
||||
self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
|
@ -148,8 +148,8 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
|
||||
|
||||
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
|
||||
self.cond_intg = bnb.nn.Linear8bitLt(prenet_channels*4, model_channels)
|
||||
self.intg = bnb.nn.Linear8bitLt(prenet_channels*2, model_channels)
|
||||
self.cond_intg = ml.Linear(prenet_channels*4, model_channels)
|
||||
self.intg = ml.Linear(prenet_channels*2, model_channels)
|
||||
|
||||
self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
|
||||
|
@ -249,7 +249,7 @@ class DiffusionTts(nn.Module):
|
|||
|
||||
embedding_dim = model_channels * 8
|
||||
# nn.Embedding
|
||||
self.code_embedding = bnb.nn.StableEmbedding(num_tokens+1, embedding_dim)
|
||||
self.code_embedding = ml.Embedding(num_tokens+1, embedding_dim)
|
||||
self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
|
||||
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
|
||||
self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1)
|
||||
|
@ -257,7 +257,7 @@ class DiffusionTts(nn.Module):
|
|||
self.enable_unaligned_inputs = enabled_unaligned_inputs
|
||||
if enabled_unaligned_inputs:
|
||||
# nn.Embedding
|
||||
self.unaligned_embedder = bnb.nn.StableEmbedding(num_unaligned_tokens, embedding_dim)
|
||||
self.unaligned_embedder = ml.Embedding(num_unaligned_tokens, embedding_dim)
|
||||
self.unaligned_encoder = CheckpointedXTransformerEncoder(
|
||||
max_seq_len=-1,
|
||||
use_pos_emb=False,
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
from x_transformers import Encoder
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
|
||||
|
@ -208,7 +208,7 @@ class DiffusionTts(nn.Module):
|
|||
# transformer network.
|
||||
self.code_converter = nn.Sequential(
|
||||
# nn.Embedding
|
||||
bnb.nn.StableEmbedding(in_tokens, conditioning_dim),
|
||||
ml.Embedding(in_tokens, conditioning_dim),
|
||||
CheckpointedXTransformerEncoder(
|
||||
needs_permute=False,
|
||||
max_seq_len=-1,
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
|
||||
|
@ -196,7 +196,7 @@ class DiffusionTtsFlat(nn.Module):
|
|||
# transformer network.
|
||||
|
||||
# nn.Embedding
|
||||
self.code_embedding = bnb.nn.StableEmbedding(in_tokens, model_channels)
|
||||
self.code_embedding = ml.Embedding(in_tokens, model_channels)
|
||||
self.code_converter = nn.Sequential(
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
|
|
|
@ -13,7 +13,7 @@ from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_e
|
|||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""
|
||||
|
@ -282,10 +282,10 @@ class UnifiedVoice(nn.Module):
|
|||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
self.average_conditioning_embeddings = average_conditioning_embeddings
|
||||
# nn.Embedding
|
||||
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens, model_dim)
|
||||
self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
|
||||
if use_mel_codes_as_input:
|
||||
# nn.Embedding
|
||||
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
|
||||
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
|
||||
else:
|
||||
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
|
||||
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
||||
|
@ -298,8 +298,8 @@ class UnifiedVoice(nn.Module):
|
|||
self.text_solo_embedding = 0
|
||||
|
||||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens)
|
||||
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
|
||||
self.text_head = ml.Linear(model_dim, self.number_text_tokens)
|
||||
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
|
||||
|
||||
# Initialize the embeddings per the GPT-2 scheme
|
||||
embeddings = [self.text_embedding]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from transformers import GPT2Config, GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
|
@ -274,16 +274,16 @@ class UnifiedVoice(nn.Module):
|
|||
self.mel_length_compression = mel_length_compression
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
# nn.Embedding
|
||||
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim)
|
||||
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
|
||||
# nn.Embedding
|
||||
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
|
||||
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
|
||||
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
||||
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
|
||||
|
||||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1)
|
||||
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
|
||||
self.aligned_head = bnb.nn.Linear8bitLt(model_dim, number_aligned_text_codes)
|
||||
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
|
||||
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
|
||||
self.aligned_head = ml.Linear(model_dim, number_aligned_text_codes)
|
||||
|
||||
# Initialize the embeddings per the GPT-2 scheme
|
||||
embeddings = [self.text_embedding, self.mel_embedding]
|
||||
|
|
|
@ -11,7 +11,7 @@ from models.audio.tts.transformer_builders import build_hf_gpt_transformer
|
|||
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
|
@ -257,16 +257,16 @@ class UnifiedVoice(nn.Module):
|
|||
self.mel_length_compression = mel_length_compression
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
# nn.Embedding
|
||||
self.text_embedding = bnb.nn.StableEmbedding(self.number_text_tokens*types+1, model_dim)
|
||||
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
|
||||
# nn.Embedding
|
||||
self.mel_embedding = bnb.nn.StableEmbedding(self.number_mel_codes, model_dim)
|
||||
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
|
||||
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
||||
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
|
||||
|
||||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = bnb.nn.Linear8bitLt(model_dim, self.number_text_tokens*types+1)
|
||||
self.mel_head = bnb.nn.Linear8bitLt(model_dim, self.number_mel_codes)
|
||||
self.alignment_head = bnb.nn.Linear8bitLt(model_dim, 256)
|
||||
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
|
||||
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
|
||||
self.alignment_head = ml.Linear(model_dim, 256)
|
||||
|
||||
if only_alignment_head:
|
||||
for p in self.parameters():
|
||||
|
|
|
@ -8,7 +8,7 @@ from models.audio.tts.mini_encoder import AudioMiniEncoder
|
|||
from trainer.injectors.spec_augment import spec_augment
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -37,7 +37,7 @@ class VoiceCLIP(nn.Module):
|
|||
self.encoder = AudioMiniEncoder(80, encoder_output)
|
||||
if pretrained_encoder_dict_path is not None:
|
||||
self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
|
||||
self.to_latent = bnb.nn.Linear8bitLt(encoder_output, dim_latent, bias=False)
|
||||
self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False)
|
||||
self.temperature = nn.Parameter(torch.tensor(1.))
|
||||
self.mel_compression_ratio = mel_compression_ratio
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
|
|||
|
||||
from models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from trainer.networks import register_model
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class CheckpointedLayer(nn.Module):
|
||||
|
@ -58,7 +58,7 @@ class Wav2VecMatcher(nn.Module):
|
|||
self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1,
|
||||
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
|
||||
# nn.Embedding
|
||||
self.text_embedding = bnb.nn.StableEmbedding(num_text_tokens, model_dim)
|
||||
self.text_embedding = ml.Embedding(num_text_tokens, model_dim)
|
||||
self.encoder = CheckpointedXTransformer(
|
||||
max_seq_len=-1,
|
||||
use_pos_emb=False,
|
||||
|
@ -75,8 +75,8 @@ class Wav2VecMatcher(nn.Module):
|
|||
)
|
||||
self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim))
|
||||
self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim))
|
||||
self.w2v_query_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim)
|
||||
self.w2v_value_encoder = bnb.nn.Linear8bitLt(WAV2VEC_CHANNELS, model_dim)
|
||||
self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
|
||||
self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
|
||||
self.decoder = CheckpointedXTransformer(
|
||||
max_seq_len=-1, # Should be unused
|
||||
use_pos_emb=False,
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
@ -99,7 +99,7 @@ class ResNet(nn.Module):
|
|||
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
|
||||
self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes)
|
||||
self.fc = ml.Linear(256 * block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
"""make resnet layers(by layer i didnt mean this 'layer' was the
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
import torchvision
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
|
@ -195,5 +195,5 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
|||
def register_resnet50(opt_net, opt):
|
||||
model = resnet50(pretrained=opt_net['pretrained'])
|
||||
if opt_net['custom_head_logits']:
|
||||
model.fc = bnb.nn.Linear8bitLt(512 * 4, opt_net['custom_head_logits'])
|
||||
model.fc = ml.Linear(512 * 4, opt_net['custom_head_logits'])
|
||||
return model
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
@ -102,7 +102,7 @@ class ResNet(nn.Module):
|
|||
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
|
||||
self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = bnb.nn.Linear8bitLt(256 * block.expansion, num_classes)
|
||||
self.fc = ml.Linear(256 * block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
"""make resnet layers(by layer i didnt mean this 'layer' was the
|
||||
|
|
|
@ -11,7 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
|||
from models.vqvae.scaled_weight_conv import ScaledWeightConv
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
|
@ -214,7 +214,7 @@ class ResNet(nn.Module):
|
|||
self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = bnb.nn.Linear8bitLt(512 * block.expansion, num_classes)
|
||||
self.fc = ml.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, ScaledWeightConv):
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
class WideKernelVgg(nn.Module):
|
||||
def __init__(self, nf=64, num_classes=2):
|
||||
|
@ -49,9 +49,9 @@ class WideKernelVgg(nn.Module):
|
|||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2),
|
||||
nn.Flatten(),
|
||||
bnb.nn.Linear8bitLt(nf * 8 * 4 * 2, 100),
|
||||
ml.Linear(nf * 8 * 4 * 2, 100),
|
||||
nn.ReLU(),
|
||||
bnb.nn.Linear8bitLt(100, num_classes)
|
||||
ml.Linear(100, num_classes)
|
||||
)
|
||||
|
||||
# These normalization constants should be derived experimentally.
|
||||
|
|
|
@ -10,7 +10,7 @@ from models.arch_util import AttentionBlock
|
|||
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get, checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -60,7 +60,7 @@ class ConvFormatEmbedding(nn.Module):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
|
||||
self.emb = ml.Embedding(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.emb(x)
|
||||
|
@ -101,9 +101,9 @@ class CLVP(nn.Module):
|
|||
self.mask_conditioning_percentage = mask_conditioning_percentage
|
||||
|
||||
# nn.Embedding
|
||||
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, model_dim)
|
||||
self.text_emb = ml.Embedding(num_text_tokens, model_dim)
|
||||
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
|
||||
self.to_text_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
|
||||
self.to_text_latent = ml.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.distributed_collect = distributed_collect
|
||||
|
||||
if mel_codes is None:
|
||||
|
@ -111,7 +111,7 @@ class CLVP(nn.Module):
|
|||
else:
|
||||
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
|
||||
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
|
||||
self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
|
||||
self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
|
|
|
@ -9,7 +9,7 @@ from models.arch_util import AttentionBlock
|
|||
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get, checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -180,7 +180,7 @@ class ConvFormatEmbedding(nn.Module):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
|
||||
self.emb = ml.Embedding(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.emb(x)
|
||||
|
@ -205,8 +205,8 @@ class ContrastiveAudio(nn.Module):
|
|||
self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
|
||||
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
|
||||
self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent)
|
||||
self.to_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
|
||||
self.to_latent2 = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
|
||||
self.to_latent = ml.Linear(latent_dim, latent_dim, bias=False)
|
||||
self.to_latent2 = ml.Linear(latent_dim, latent_dim, bias=False)
|
||||
|
||||
self.to_latent2.weight.data = self.to_latent.weight.data
|
||||
self.to_latent2.weight.DO_NOT_TRAIN = True
|
||||
|
|
|
@ -10,7 +10,7 @@ from models.arch_util import AttentionBlock
|
|||
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get, checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -60,7 +60,7 @@ class ConvFormatEmbedding(nn.Module):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.emb = bnb.nn.StableEmbedding(*args, **kwargs)
|
||||
self.emb = ml.Embedding(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.emb(x)
|
||||
|
@ -88,14 +88,14 @@ class CVVP(nn.Module):
|
|||
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
|
||||
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
|
||||
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
|
||||
self.to_conditioning_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
|
||||
self.to_conditioning_latent = ml.Linear(latent_dim, latent_dim, bias=False)
|
||||
|
||||
if mel_codes is None:
|
||||
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
|
||||
else:
|
||||
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
|
||||
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
|
||||
self.to_speech_latent = bnb.nn.Linear8bitLt(latent_dim, latent_dim, bias=False)
|
||||
self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch import einsum
|
|||
from models.lucidrains.dalle.transformer import Transformer
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -47,19 +47,19 @@ class MelTextCLIP(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text)
|
||||
self.text_emb = ml.Embedding(num_text_tokens, dim_text)
|
||||
# nn.Embedding
|
||||
self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text)
|
||||
self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
|
||||
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
|
||||
heads=text_heads, rotary_emb=False)
|
||||
self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False)
|
||||
self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
|
||||
|
||||
self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1)
|
||||
# nn.Embedding
|
||||
self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
|
||||
self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
|
||||
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
|
||||
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
|
||||
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
|
||||
self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
|
||||
|
||||
self.temperature = nn.Parameter(torch.tensor(1.))
|
||||
self.text_mask_percentage = text_mask_percentage
|
||||
|
|
|
@ -7,7 +7,7 @@ from models.audio.tts.unified_voice2 import ConditioningEncoder
|
|||
from models.lucidrains.dalle.transformer import Transformer
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -46,7 +46,7 @@ class VoiceCondCLIP(nn.Module):
|
|||
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
||||
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
|
||||
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
|
||||
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
|
||||
self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
|
||||
|
||||
self.temperature = nn.Parameter(torch.tensor(1.))
|
||||
self.voice_mask_percentage = voice_mask_percentage
|
||||
|
|
|
@ -11,7 +11,7 @@ from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
|
|||
from models.lucidrains.dalle.transformer import Transformer
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -55,12 +55,12 @@ class VoiceCLIP(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.text_emb = bnb.nn.StableEmbedding(num_text_tokens, dim_text)
|
||||
self.to_text_latent = bnb.nn.Linear8bitLt(dim_text, dim_latent, bias=False)
|
||||
self.text_emb = ml.Embedding(num_text_tokens, dim_text)
|
||||
self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
|
||||
|
||||
# nn.Embedding
|
||||
self.speech_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
|
||||
self.to_speech_latent = bnb.nn.Linear8bitLt(dim_speech, dim_latent, bias=False)
|
||||
self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech)
|
||||
self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
|
||||
|
||||
if use_xformers:
|
||||
self.text_transformer = CheckpointedXTransformerEncoder(
|
||||
|
@ -109,9 +109,9 @@ class VoiceCLIP(nn.Module):
|
|||
self.distributed_collect = distributed_collect
|
||||
if not use_xformers:
|
||||
# nn.Embedding
|
||||
self.text_pos_emb = bnb.nn.StableEmbedding(text_seq_len, dim_text)
|
||||
self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
|
||||
# nn.Embedding
|
||||
self.speech_pos_emb = bnb.nn.StableEmbedding(num_speech_tokens, dim_speech)
|
||||
self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
|
||||
|
||||
def embed_text(self, text):
|
||||
text_mask = torch.ones_like(text.float()).bool()
|
||||
|
|
|
@ -6,7 +6,7 @@ import math
|
|||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
|
@ -37,7 +37,7 @@ def linear(*args, **kwargs):
|
|||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return bnb.nn.Linear8bitLt(*args, **kwargs)
|
||||
return ml.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
|
|
|
@ -6,7 +6,7 @@ from models.arch_util import ConvGnLelu, default_init_weights, make_layer
|
|||
from models.diffusion.nn import timestep_embedding
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
# Conditionally uses torch's checkpoint functionality if it is enabled in the opt file.
|
||||
|
@ -29,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
|
|||
self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True)
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
bnb.nn.Linear8bitLt(
|
||||
ml.Linear(
|
||||
mid_channels*4,
|
||||
mid_channels,
|
||||
),
|
||||
|
@ -144,9 +144,9 @@ class RRDBNet(nn.Module):
|
|||
# Guided diffusion uses a time embedding.
|
||||
time_embed_dim = mid_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(mid_channels, time_embed_dim),
|
||||
ml.Linear(mid_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
bnb.nn.Linear8bitLt(time_embed_dim, time_embed_dim),
|
||||
ml.Linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.body = make_layer(
|
||||
|
|
|
@ -20,7 +20,7 @@ from models.diffusion.nn import (
|
|||
)
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
|
@ -517,7 +517,7 @@ class UNetModel(nn.Module):
|
|||
|
||||
if self.num_classes is not None:
|
||||
# nn.Embedding
|
||||
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
|
||||
self.label_emb = ml.Embedding(num_classes, time_embed_dim)
|
||||
self.use_raw_y_as_embedding = use_raw_y_as_embedding
|
||||
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
|
||||
|
||||
|
@ -869,16 +869,16 @@ class EncoderUNetModel(nn.Module):
|
|||
)
|
||||
elif pool == "spatial":
|
||||
self.out = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(self._feature_size, 2048),
|
||||
ml.Linear(self._feature_size, 2048),
|
||||
nn.ReLU(),
|
||||
bnb.nn.Linear8bitLt(2048, self.out_channels),
|
||||
ml.Linear(2048, self.out_channels),
|
||||
)
|
||||
elif pool == "spatial_v2":
|
||||
self.out = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(self._feature_size, 2048),
|
||||
ml.Linear(self._feature_size, 2048),
|
||||
normalization(2048),
|
||||
nn.SiLU(),
|
||||
bnb.nn.Linear8bitLt(2048, self.out_channels),
|
||||
ml.Linear(2048, self.out_channels),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {pool} pooling")
|
||||
|
|
|
@ -26,7 +26,7 @@ from models.diffusion.nn import (
|
|||
)
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
|
@ -478,7 +478,7 @@ class UNetModel(nn.Module):
|
|||
|
||||
if self.num_classes is not None:
|
||||
# nn.Embedding
|
||||
self.label_emb = bnb.nn.StableEmbedding(num_classes, time_embed_dim)
|
||||
self.label_emb = ml.Embedding(num_classes, time_embed_dim)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
|
@ -738,7 +738,7 @@ class ResNetEncoder(nn.Module):
|
|||
dilate=replace_stride_with_dilation[2])
|
||||
f=512
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = bnb.nn.Linear8bitLt(f * block.expansion, output_dim)
|
||||
self.fc = ml.Linear(f * block.expansion, output_dim)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class Discriminator_VGG_128(nn.Module):
|
||||
|
@ -47,8 +47,8 @@ class Discriminator_VGG_128(nn.Module):
|
|||
input_img_factor = input_img_factor // 2
|
||||
final_nf = nf * 16
|
||||
|
||||
self.linear1 = bnb.nn.Linear8bitLt(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
|
||||
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
|
||||
self.linear1 = ml.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
|
||||
self.linear2 = ml.Linear(100, 1)
|
||||
|
||||
# activation function
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
@ -130,8 +130,8 @@ class Discriminator_VGG_128_GN(nn.Module):
|
|||
# activation function
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
|
||||
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
|
||||
self.linear1 = ml.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
|
||||
self.linear2 = ml.Linear(100, 1)
|
||||
|
||||
def compute_body(self, x):
|
||||
fea = self.lrelu(self.conv0_0(x))
|
||||
|
@ -220,8 +220,8 @@ class DiscriminatorVGG448GN(nn.Module):
|
|||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
final_nf = nf * 8
|
||||
self.linear1 = bnb.nn.Linear8bitLt(int(final_nf * 7 * 7), 100)
|
||||
self.linear2 = bnb.nn.Linear8bitLt(100, 1)
|
||||
self.linear1 = ml.Linear(int(final_nf * 7 * 7), 100)
|
||||
self.linear2 = ml.Linear(100, 1)
|
||||
|
||||
# Assign all new heads to the new param group.2
|
||||
for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]:
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
|
@ -15,7 +15,7 @@ def initialize_weights(net_l, scale=1):
|
|||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, bnb.nn.Linear8bitLt):
|
||||
elif isinstance(m, ml.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
|
|
|
@ -28,7 +28,7 @@ except:
|
|||
APEX_AVAILABLE = False
|
||||
|
||||
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
num_cores = multiprocessing.cpu_count()
|
||||
|
||||
|
@ -352,7 +352,7 @@ class RGBBlock(nn.Module):
|
|||
def __init__(self, latent_dim, input_channel, upsample, rgba=False):
|
||||
super().__init__()
|
||||
self.input_channel = input_channel
|
||||
self.to_style = bnb.nn.Linear8bitLt(latent_dim, input_channel)
|
||||
self.to_style = ml.Linear(latent_dim, input_channel)
|
||||
|
||||
out_filters = 3 if not rgba else 4
|
||||
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
|
||||
|
@ -490,16 +490,16 @@ class GeneratorBlockWithStructure(nn.Module):
|
|||
|
||||
# Uses stylegan1 style blocks for injecting structural latent.
|
||||
self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
|
||||
self.to_noise0 = bnb.nn.Linear8bitLt(1, filters)
|
||||
self.to_noise0 = ml.Linear(1, filters)
|
||||
self.noise0 = equal_lr(NoiseInjection(filters))
|
||||
self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
|
||||
|
||||
self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, filters)
|
||||
self.to_noise1 = bnb.nn.Linear8bitLt(1, filters)
|
||||
self.to_style1 = ml.Linear(latent_dim, filters)
|
||||
self.to_noise1 = ml.Linear(1, filters)
|
||||
self.conv1 = Conv2DMod(filters, filters, 3)
|
||||
|
||||
self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters)
|
||||
self.to_noise2 = bnb.nn.Linear8bitLt(1, filters)
|
||||
self.to_style2 = ml.Linear(latent_dim, filters)
|
||||
self.to_noise2 = ml.Linear(1, filters)
|
||||
self.conv2 = Conv2DMod(filters, filters, 3)
|
||||
|
||||
self.activation = leaky_relu()
|
||||
|
@ -541,12 +541,12 @@ class GeneratorBlock(nn.Module):
|
|||
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
|
||||
input_channels = input_channels * 2
|
||||
|
||||
self.to_style1 = bnb.nn.Linear8bitLt(latent_dim, input_channels)
|
||||
self.to_noise1 = bnb.nn.Linear8bitLt(1, filters)
|
||||
self.to_style1 = ml.Linear(latent_dim, input_channels)
|
||||
self.to_noise1 = ml.Linear(1, filters)
|
||||
self.conv1 = Conv2DMod(input_channels, filters, 3)
|
||||
|
||||
self.to_style2 = bnb.nn.Linear8bitLt(latent_dim, filters)
|
||||
self.to_noise2 = bnb.nn.Linear8bitLt(1, filters)
|
||||
self.to_style2 = ml.Linear(latent_dim, filters)
|
||||
self.to_noise2 = ml.Linear(1, filters)
|
||||
self.conv2 = Conv2DMod(filters, filters, 3)
|
||||
|
||||
self.activation = leaky_relu()
|
||||
|
@ -725,7 +725,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
|
|||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt} and hasattr(m, 'weight'):
|
||||
if type(m) in {nn.Conv2d, ml.Linear} and hasattr(m, 'weight'):
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
for block in self.gen.blocks:
|
||||
|
@ -805,7 +805,7 @@ class StyleGan2Discriminator(nn.Module):
|
|||
|
||||
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
|
||||
self.flatten = Flatten()
|
||||
self.to_logit = bnb.nn.Linear8bitLt(latent_dim, 1)
|
||||
self.to_logit = ml.Linear(latent_dim, 1)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
|
@ -837,7 +837,7 @@ class StyleGan2Discriminator(nn.Module):
|
|||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if type(m) in {nn.Conv2d, bnb.nn.Linear8bitLt}:
|
||||
if type(m) in {nn.Conv2d, ml.Linear}:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from torch import nn
|
|||
from data.images.byol_attachment import RandomApply
|
||||
from trainer.networks import register_model, create_model
|
||||
from utils.util import checkpoint, opt_get
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def default(val, def_val):
|
||||
|
@ -79,10 +79,10 @@ class MLP(nn.Module):
|
|||
def __init__(self, dim, projection_size, hidden_size=4096):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(dim, hidden_size),
|
||||
ml.Linear(dim, hidden_size),
|
||||
nn.BatchNorm1d(hidden_size),
|
||||
nn.ReLU(inplace=True),
|
||||
bnb.nn.Linear8bitLt(hidden_size, projection_size)
|
||||
ml.Linear(hidden_size, projection_size)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -104,10 +104,10 @@ class StructuralMLP(nn.Module):
|
|||
nn.BatchNorm2d(c),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Flatten(),
|
||||
bnb.nn.Linear8bitLt(flattened_dim, hidden_size),
|
||||
ml.Linear(flattened_dim, hidden_size),
|
||||
nn.BatchNorm1d(hidden_size),
|
||||
nn.ReLU(inplace=True),
|
||||
bnb.nn.Linear8bitLt(hidden_size, projection_size)
|
||||
ml.Linear(hidden_size, projection_size)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
|
||||
|
@ -109,8 +109,8 @@ class FixupResNet(nn.Module):
|
|||
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
|
||||
self.bias2 = nn.Parameter(torch.zeros(1))
|
||||
reduced_img_sz = int(input_img_size / 32)
|
||||
self.fc1 = bnb.nn.Linear8bitLt(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
|
||||
self.fc2 = bnb.nn.Linear8bitLt(100, num_classes)
|
||||
self.fc1 = ml.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
|
||||
self.fc2 = ml.Linear(100, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, FixupBasicBlock):
|
||||
|
@ -125,7 +125,7 @@ class FixupResNet(nn.Module):
|
|||
if m.downsample is not None:
|
||||
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
|
||||
'''
|
||||
elif isinstance(m, bnb.nn.Linear8bitLt):
|
||||
elif isinstance(m, ml.Linear):
|
||||
nn.init.constant_(m.weight, 0)
|
||||
nn.init.constant_(m.bias, 0)'''
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
from models.arch_util import ResBlock
|
||||
from models.lucidrains.x_transformers import Encoder
|
||||
from trainer.networks import register_model
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class VitLatent(nn.Module):
|
||||
|
@ -32,10 +32,10 @@ class VitLatent(nn.Module):
|
|||
do_checkpointing=True
|
||||
)
|
||||
|
||||
self.mlp = nn.Sequential(bnb.nn.Linear8bitLt(hidden_dim, hidden_dim*2),
|
||||
self.mlp = nn.Sequential(ml.Linear(hidden_dim, hidden_dim*2),
|
||||
nn.BatchNorm1d(hidden_dim*2),
|
||||
nn.ReLU(inplace=True),
|
||||
bnb.nn.Linear8bitLt(hidden_dim*2, hidden_dim))
|
||||
ml.Linear(hidden_dim*2, hidden_dim))
|
||||
|
||||
def provide_ema(self, ema):
|
||||
self.ema = ema
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||
from einops import rearrange, repeat
|
||||
|
||||
from rotary_embedding_torch import apply_rotary_emb
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
# helpers
|
||||
|
||||
|
@ -48,9 +48,9 @@ class Attention(nn.Module):
|
|||
self.stable = stable
|
||||
self.causal = causal
|
||||
|
||||
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
|
||||
self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(inner_dim, dim),
|
||||
ml.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
|
@ -103,10 +103,10 @@ class SparseConvCausalAttention(nn.Module):
|
|||
|
||||
self.stable = stable
|
||||
|
||||
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
|
||||
self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(inner_dim, dim),
|
||||
ml.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
|
@ -223,10 +223,10 @@ class SparseAxialCausalAttention(nn.Module):
|
|||
|
||||
self.stable = stable
|
||||
|
||||
self.to_qkv = bnb.nn.Linear8bitLt(dim, inner_dim * 3, bias = False)
|
||||
self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(inner_dim, dim),
|
||||
ml.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from models.lucidrains.dalle.attention import Attention, SparseAttention, Sparse
|
|||
|
||||
from rotary_embedding_torch import RotaryEmbedding, broadcat
|
||||
from g_mlp_pytorch import gMLPBlock
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
# helpers
|
||||
|
||||
|
@ -79,10 +79,10 @@ class FeedForward(nn.Module):
|
|||
def __init__(self, dim, dropout = 0., mult = 4.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(dim, dim * mult * 2),
|
||||
ml.Linear(dim, dim * mult * 2),
|
||||
GEGLU(),
|
||||
nn.Dropout(dropout),
|
||||
bnb.nn.Linear8bitLt(dim * mult, dim)
|
||||
ml.Linear(dim * mult, dim)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -21,7 +21,7 @@ try:
|
|||
APEX_AVAILABLE = True
|
||||
except:
|
||||
APEX_AVAILABLE = False
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
# helpers
|
||||
|
||||
|
@ -357,10 +357,10 @@ class FeedForward(nn.Module):
|
|||
activation = default(activation, nn.GELU)
|
||||
|
||||
self.glu = glu
|
||||
self.w1 = bnb.nn.Linear8bitLt(dim, dim * mult * (2 if glu else 1))
|
||||
self.w1 = ml.Linear(dim, dim * mult * (2 if glu else 1))
|
||||
self.act = activation()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.w2 = bnb.nn.Linear8bitLt(dim * mult, dim)
|
||||
self.w2 = ml.Linear(dim * mult, dim)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
if not self.glu:
|
||||
|
@ -402,10 +402,10 @@ class Attention(nn.Module):
|
|||
self.global_heads = heads - local_heads
|
||||
self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
|
||||
|
||||
self.to_q = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
|
||||
self.to_k = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
|
||||
self.to_v = bnb.nn.Linear8bitLt(dim, inner_dim, bias = qkv_bias)
|
||||
self.to_out = bnb.nn.Linear8bitLt(inner_dim, dim, bias = attn_out_bias)
|
||||
self.to_q = ml.Linear(dim, inner_dim, bias = qkv_bias)
|
||||
self.to_k = ml.Linear(dim, inner_dim, bias = qkv_bias)
|
||||
self.to_v = ml.Linear(dim, inner_dim, bias = qkv_bias)
|
||||
self.to_out = ml.Linear(inner_dim, dim, bias = attn_out_bias)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs):
|
||||
|
@ -460,7 +460,7 @@ class AbsolutePositionalEmbedding(nn.Module):
|
|||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.emb = bnb.nn.StableEmbedding(max_seq_len, dim)
|
||||
self.emb = ml.Embedding(max_seq_len, dim)
|
||||
|
||||
def forward(self, x):
|
||||
t = torch.arange(x.shape[1], device=x.device)
|
||||
|
@ -622,7 +622,7 @@ class PerformerLM(nn.Module):
|
|||
|
||||
self.max_seq_len = max_seq_len
|
||||
# nn.Embedding
|
||||
self.token_emb = bnb.nn.StableEmbedding(num_tokens, dim)
|
||||
self.token_emb = ml.Embedding(num_tokens, dim)
|
||||
|
||||
if rotary_position_emb:
|
||||
self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
|
||||
|
@ -639,7 +639,7 @@ class PerformerLM(nn.Module):
|
|||
|
||||
self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens)
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.to_out = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embed else None
|
||||
self.to_out = ml.Linear(dim, num_tokens) if not tie_embed else None
|
||||
|
||||
def check_redraw_projections(self):
|
||||
self.performer.check_redraw_projections()
|
||||
|
|
|
@ -8,7 +8,7 @@ from torch.cuda.amp import autocast
|
|||
|
||||
from einops import rearrange, repeat
|
||||
from contextlib import contextmanager
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def par(t, nm):
|
||||
|
@ -356,9 +356,9 @@ class VectorQuantize(nn.Module):
|
|||
|
||||
codebook_dim = default(codebook_dim, dim)
|
||||
requires_projection = codebook_dim != dim
|
||||
self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection \
|
||||
self.project_in = ml.Linear(dim, codebook_dim) if requires_projection \
|
||||
else nn.Identity()
|
||||
self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection \
|
||||
self.project_out = ml.Linear(codebook_dim, dim) if requires_projection \
|
||||
else nn.Identity()
|
||||
|
||||
self.eps = eps
|
||||
|
|
|
@ -11,7 +11,7 @@ from einops import rearrange, repeat, reduce
|
|||
from einops.layers.torch import Rearrange
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
|
@ -127,7 +127,7 @@ class AbsolutePositionalEmbedding(nn.Module):
|
|||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
# nn.Embedding
|
||||
self.emb = bnb.nn.StableEmbedding(max_seq_len, dim)
|
||||
self.emb = ml.Embedding(max_seq_len, dim)
|
||||
|
||||
def forward(self, x):
|
||||
n = torch.arange(x.shape[1], device=x.device)
|
||||
|
@ -157,7 +157,7 @@ class RelativePositionBias(nn.Module):
|
|||
self.num_buckets = num_buckets
|
||||
self.max_distance = max_distance
|
||||
# nn.Embedding
|
||||
self.relative_attention_bias = bnb.nn.StableEmbedding(num_buckets, heads)
|
||||
self.relative_attention_bias = ml.Embedding(num_buckets, heads)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
|
||||
|
@ -363,7 +363,7 @@ class RMSScaleShiftNorm(nn.Module):
|
|||
self.cdim = 1
|
||||
self.pdim = -1
|
||||
else:
|
||||
self.scale_shift_process = bnb.nn.Linear8bitLt(embed_dim, dim * 2, bias=bias)
|
||||
self.scale_shift_process = ml.Linear(embed_dim, dim * 2, bias=bias)
|
||||
self.cdim = -1
|
||||
self.pdim = 1
|
||||
|
||||
|
@ -450,7 +450,7 @@ class GLU(nn.Module):
|
|||
def __init__(self, dim_in, dim_out, activation):
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.proj = bnb.nn.Linear8bitLt(dim_in, dim_out * 2)
|
||||
self.proj = ml.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
|
@ -475,7 +475,7 @@ class FeedForward(nn.Module):
|
|||
activation = ReluSquared() if relu_squared else nn.GELU()
|
||||
|
||||
project_in = nn.Sequential(
|
||||
bnb.nn.Linear8bitLt(dim, inner_dim),
|
||||
ml.Linear(dim, inner_dim),
|
||||
activation
|
||||
) if not glu else GLU(dim, inner_dim, activation)
|
||||
|
||||
|
@ -483,7 +483,7 @@ class FeedForward(nn.Module):
|
|||
project_in,
|
||||
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
|
||||
nn.Dropout(dropout),
|
||||
bnb.nn.Linear8bitLt(inner_dim, dim_out)
|
||||
ml.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
# init last linear layer to 0
|
||||
|
@ -538,16 +538,16 @@ class Attention(nn.Module):
|
|||
qk_dim = int(collab_compression * qk_dim)
|
||||
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
|
||||
|
||||
self.to_q = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False)
|
||||
self.to_k = bnb.nn.Linear8bitLt(dim, qk_dim, bias=False)
|
||||
self.to_v = bnb.nn.Linear8bitLt(dim, v_dim, bias=False)
|
||||
self.to_q = ml.Linear(dim, qk_dim, bias=False)
|
||||
self.to_k = ml.Linear(dim, qk_dim, bias=False)
|
||||
self.to_v = ml.Linear(dim, v_dim, bias=False)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# add GLU gating for aggregated values, from alphafold2
|
||||
self.to_v_gate = None
|
||||
if gate_values:
|
||||
self.to_v_gate = bnb.nn.Linear8bitLt(dim, v_dim)
|
||||
self.to_v_gate = ml.Linear(dim, v_dim)
|
||||
nn.init.constant_(self.to_v_gate.weight, 0)
|
||||
nn.init.constant_(self.to_v_gate.bias, 1)
|
||||
|
||||
|
@ -584,7 +584,7 @@ class Attention(nn.Module):
|
|||
# attention on attention
|
||||
self.attn_on_attn = on_attn
|
||||
out_dim = default(out_dim, dim)
|
||||
self.to_out = nn.Sequential(bnb.nn.Linear8bitLt(v_dim, out_dim * 2), nn.GLU()) if on_attn else bnb.nn.Linear8bitLt(v_dim, out_dim)
|
||||
self.to_out = nn.Sequential(ml.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, out_dim)
|
||||
|
||||
self.rel_pos_bias = rel_pos_bias
|
||||
if rel_pos_bias:
|
||||
|
@ -1080,7 +1080,7 @@ class ViTransformerWrapper(nn.Module):
|
|||
self.patch_size = patch_size
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = bnb.nn.Linear8bitLt(patch_dim, dim)
|
||||
self.patch_to_embedding = ml.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
|
@ -1139,18 +1139,18 @@ class TransformerWrapper(nn.Module):
|
|||
self.shift_mem_down = shift_mem_down
|
||||
|
||||
# nn.Embedding
|
||||
self.token_emb = bnb.nn.StableEmbedding(num_tokens, emb_dim)
|
||||
self.token_emb = ml.Embedding(num_tokens, emb_dim)
|
||||
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
|
||||
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
||||
self.emb_dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.project_emb = bnb.nn.Linear8bitLt(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
||||
self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
||||
self.attn_layers = attn_layers
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.init_()
|
||||
|
||||
self.to_logits = bnb.nn.Linear8bitLt(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
||||
self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
||||
|
||||
# memory tokens (like [cls]) from Memory Transformers paper
|
||||
num_memory_tokens = default(num_memory_tokens, 0)
|
||||
|
@ -1237,12 +1237,12 @@ class ContinuousTransformerWrapper(nn.Module):
|
|||
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
||||
self.emb_dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.project_in = bnb.nn.Linear8bitLt(dim_in, dim) if exists(dim_in) else nn.Identity()
|
||||
self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
|
||||
|
||||
self.attn_layers = attn_layers
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.project_out = bnb.nn.Linear8bitLt(dim, dim_out) if exists(dim_out) else nn.Identity()
|
||||
self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|||
from torch import einsum
|
||||
|
||||
from utils.weight_scheduler import LinearDecayWeightScheduler
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
class GumbelQuantizer(nn.Module):
|
||||
|
@ -12,7 +12,7 @@ class GumbelQuantizer(nn.Module):
|
|||
super().__init__()
|
||||
self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
|
||||
# nn.Embedding
|
||||
self.codebook = bnb.nn.StableEmbedding(num_tokens, codebook_dim)
|
||||
self.codebook = ml.Embedding(num_tokens, codebook_dim)
|
||||
self.straight_through = straight_through
|
||||
self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
|
||||
self.step = 0
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|||
from einops import rearrange, repeat
|
||||
|
||||
from models.arch_util import l2norm, sample_vectors, default, ema_inplace
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
|
||||
|
@ -185,8 +185,8 @@ class VectorQuantize(nn.Module):
|
|||
|
||||
codebook_dim = default(codebook_dim, dim)
|
||||
requires_projection = codebook_dim != dim
|
||||
self.project_in = bnb.nn.Linear8bitLt(dim, codebook_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = bnb.nn.Linear8bitLt(codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
self.project_in = ml.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = ml.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
|
||||
self.eps = eps
|
||||
|
||||
|
|
41
codes/torch_intermediary/__init__.py
Normal file
41
codes/torch_intermediary/__init__.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
"""
|
||||
from bitsandbytes.nn import Linear8bitLt as Linear
|
||||
from bitsandbytes.nn import StableEmbedding as Embedding
|
||||
from bitsandbytes.optim.adam import Adam8bit as Adam
|
||||
from bitsandbytes.optim.adamw import AdamW8bit as AdamW
|
||||
"""
|
||||
"""
|
||||
from torch.nn import Linear
|
||||
from torch.nn import Embedding
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.adamw import AdamW
|
||||
"""
|
||||
|
||||
OVERRIDE_LINEAR = False
|
||||
OVERRIDE_EMBEDDING = False
|
||||
OVERRIDE_ADAM = True
|
||||
OVERRIDE_ADAMW = True
|
||||
USE_STABLE_EMBEDDING = True
|
||||
|
||||
if OVERRIDE_LINEAR:
|
||||
from bitsandbytes.nn import Linear8bitLt as Linear
|
||||
else:
|
||||
from torch.nn import Linear
|
||||
|
||||
if OVERRIDE_EMBEDDING:
|
||||
if USE_STABLE_EMBEDDING:
|
||||
from bitsandbytes.nn import StableEmbedding as Embedding
|
||||
else:
|
||||
from bitsandbytes.nn import Embedding as Embedding
|
||||
else:
|
||||
from torch.nn import Embedding
|
||||
|
||||
if OVERRIDE_ADAM:
|
||||
from bitsandbytes.optim.adam import Adam8bit as Adam
|
||||
else:
|
||||
from torch.optim.adam import Adam
|
||||
|
||||
if OVERRIDE_ADAMW:
|
||||
from bitsandbytes.optim.adamw import AdamW8bit as AdamW
|
||||
else:
|
||||
from torch.optim.adamw import AdamW
|
|
@ -21,7 +21,7 @@ import torchvision.utils as utils
|
|||
|
||||
from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
|
||||
from utils.util import opt_get, denormalize
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -338,7 +338,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
for net in self.networks.values():
|
||||
for mod in net.modules():
|
||||
fan_in = -1
|
||||
if isinstance(mod, bnb.nn.Linear8bitLt):
|
||||
if isinstance(mod, ml.Linear):
|
||||
fan_in = mod.weight.data.shape[1]
|
||||
elif isinstance(mod, nn.Conv1d):
|
||||
fan_in = mod.weight.data.shape[0]
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn as nn
|
|||
import trainer.networks as networks
|
||||
import trainer.lr_scheduler as lr_scheduler
|
||||
from .base_model import BaseModel
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -43,7 +43,7 @@ class FeatureModel(BaseModel):
|
|||
if self.rank <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
# torch.optim.Adam
|
||||
self.optimizer_G = bnb.optim.Adam8bit(optim_params, lr=train_opt['lr_G'],
|
||||
self.optimizer_G = ml.Adam(optim_params, lr=train_opt['lr_G'],
|
||||
weight_decay=wd_G,
|
||||
betas=(train_opt['beta1_G'], train_opt['beta2_G']))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
|
|
|
@ -3,7 +3,7 @@ from collections import Counter
|
|||
from collections import defaultdict
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
from utils.util import opt_get
|
||||
|
||||
|
@ -137,7 +137,7 @@ class CosineAnnealingLR_Restart(_LRScheduler):
|
|||
|
||||
if __name__ == "__main__":
|
||||
#torch.optim.Adam
|
||||
optimizer = bnb.optim.Adam8bit([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
|
||||
optimizer = ml.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
|
||||
betas=(0.9, 0.99))
|
||||
##############################
|
||||
# MultiStepLR_Restart
|
||||
|
|
|
@ -12,7 +12,7 @@ from utils.util import recursively_detach, opt_get, clip_grad_norm
|
|||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch_intermediary as ml
|
||||
|
||||
# Defines the expected API for a single training step
|
||||
class ConfigurableStep(Module):
|
||||
|
@ -84,7 +84,7 @@ class ConfigurableStep(Module):
|
|||
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
|
||||
nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
|
||||
# nn.Embedding
|
||||
emb_modules = (bnb.nn.StableEmbedding, nn.EmbeddingBag)
|
||||
emb_modules = (ml.Embedding, nn.EmbeddingBag)
|
||||
param_names_notweights = set()
|
||||
all_param_names = set()
|
||||
param_map = {}
|
||||
|
@ -126,7 +126,7 @@ class ConfigurableStep(Module):
|
|||
{ 'params': params_notweights, 'weight_decay': 0 }
|
||||
]
|
||||
# torch.optim.AdamW
|
||||
opt = bnb.optim.AdamW8bit(groups, lr=opt_config['lr'],
|
||||
opt = ml.AdamW(groups, lr=opt_config['lr'],
|
||||
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
|
||||
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
|
||||
opt._group_names = [params_names_weights, params_names_notweights]
|
||||
|
@ -145,7 +145,7 @@ class ConfigurableStep(Module):
|
|||
# parameters and just use a normal AdamW implementation. In a large network, these weights will normally
|
||||
# be a tiny fraction of the total weights.
|
||||
# torch.optim.AdamW
|
||||
opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0,
|
||||
opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
|
||||
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
|
||||
opt_unweighted._config = opt_config
|
||||
opt_unweighted._config['network'] = net_name
|
||||
|
@ -153,7 +153,7 @@ class ConfigurableStep(Module):
|
|||
self.optimizers.append(opt_unweighted)
|
||||
|
||||
# torch.optim.AdamW
|
||||
opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=bnb.optim.AdamW8bit, lr=opt_config['lr'],
|
||||
opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=ml.AdamW, lr=opt_config['lr'],
|
||||
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
|
||||
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
|
||||
opt.param_groups[0]['initial_lr'] = opt_config['lr']
|
||||
|
@ -168,7 +168,7 @@ class ConfigurableStep(Module):
|
|||
elif self.step_opt['optimizer'] == 'lamb':
|
||||
from trainer.optimizers.lamb import Lamb
|
||||
# torch.optim.AdamW
|
||||
opt_unweighted = bnb.optim.AdamW8bit(params_notweights, lr=opt_config['lr'], weight_decay=0,
|
||||
opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
|
||||
betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
|
||||
opt_unweighted._config = opt_config
|
||||
opt_unweighted._config['network'] = net_name
|
||||
|
|
Loading…
Reference in New Issue
Block a user