Merge pull request 'bitsandbytes' (#2) from bitsandbytes into master

Reviewed-on: mrq/DL-Art-School#2
mrq 2023-02-23 03:16:25 +00:00
commit 918473807f
75 changed files with 615 additions and 212 deletions

View File

@@ -0,0 +1,54 @@
import ctypes as ct
from pathlib import Path
from warnings import warn

from .cuda_setup.main import evaluate_cuda_setup


class CUDALibrary_Singleton(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        binary_name = evaluate_cuda_setup()
        package_dir = Path(__file__).parent
        binary_path = package_dir / binary_name

        if not binary_path.exists():
            print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
            legacy_binary_name = "libbitsandbytes.so"
            print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
            binary_path = package_dir / legacy_binary_name
            if not binary_path.exists():
                print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
                print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
                raise Exception('CUDA SETUP: Setup Failed!')
            # self.lib = ct.cdll.LoadLibrary(binary_path)
            self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
        else:
            print(f"CUDA SETUP: Loading binary {binary_path}...")
            # self.lib = ct.cdll.LoadLibrary(binary_path)
            self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance


lib = CUDALibrary_Singleton.get_instance().lib
try:
    lib.cadam32bit_g32
    lib.get_context.restype = ct.c_void_p
    lib.get_cusparse.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn(
        "The installed version of bitsandbytes was compiled without GPU support. "
        "8-bit optimizers and GPU quantization are unavailable."
    )
    COMPILED_WITH_CUDA = False
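For orientation, a usage sketch (not part of this file): downstream code can branch on COMPILED_WITH_CUDA before reaching for the 8-bit optimizers this PR wires in. The optimizer class name below is bitsandbytes' public Adam8bit; the model and hyperparameters are placeholders.

import torch
import bitsandbytes as bnb
from bitsandbytes.cextension import COMPILED_WITH_CUDA

model = torch.nn.Linear(512, 512).cuda()
if COMPILED_WITH_CUDA:
    # 8-bit optimizer states -- the reason bitsandbytes is being pulled in
    opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)
else:
    # CPU-only / unsupported-GPU fallback
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)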

View File

@@ -0,0 +1,166 @@
"""
extract factors the build is dependent on:
[X] compute capability
    [ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
    - CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication)
    - CuBLAS-LT: full-build 8-bit optimizer
    - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)

evaluation:
    - if paths faulty, return meaningful error
    - else:
        - determine CUDA version
        - determine capabilities
        - based on that set the default path
"""

import ctypes

from .paths import determine_cuda_runtime_lib_path


def check_cuda_result(cuda, result_val):
    # 3. Check for CUDA errors
    if result_val != 0:
        error_str = ctypes.c_char_p()
        cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
        print(f"CUDA exception! Error code: {error_str.value.decode()}")


def get_cuda_version(cuda, cudart_path):
    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
    try:
        cudart = ctypes.CDLL(cudart_path)
    except OSError:
        # TODO: shouldn't we error or at least warn here?
        print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
        return None

    version = ctypes.c_int()
    check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
    version = int(version.value)
    major = version // 1000
    minor = (version - (major * 1000)) // 10

    if major < 11:
        print('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!!')

    return f'{major}{minor}'
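A quick worked example of the decoding above (illustrative, not part of the file): cudaRuntimeGetVersion() reports an integer such as 11080 for a CUDA 11.8 runtime.

# Illustration of get_cuda_version()'s decoding, assuming a CUDA 11.8 runtime:
version = 11080                              # value written by cudaRuntimeGetVersion
major = version // 1000                      # -> 11
minor = (version - (major * 1000)) // 10     # -> 8
print(f'{major}{minor}')                     # -> '118', later spliced into the binary name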
def get_cuda_lib_handle():
    # 1. find libcuda.so library (GPU driver) (/usr/lib)
    try:
        cuda = ctypes.CDLL("libcuda.so")
    except OSError:
        # TODO: shouldn't we error or at least warn here?
        print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
        return None
    check_cuda_result(cuda, cuda.cuInit(0))

    return cuda


def get_compute_capabilities(cuda):
    """
    1. find libcuda.so library (GPU driver) (/usr/lib)
       init_device -> init variables -> call function by reference
    2. call extern C function to determine CC
       (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
    3. Check for CUDA errors
       https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
    # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
    """

    nGpus = ctypes.c_int()
    cc_major = ctypes.c_int()
    cc_minor = ctypes.c_int()
    device = ctypes.c_int()

    check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
    ccs = []
    for i in range(nGpus.value):
        check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
        ref_major = ctypes.byref(cc_major)
        ref_minor = ctypes.byref(cc_minor)
        # 2. call extern C function to determine CC
        check_cuda_result(
            cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
        )
        ccs.append(f"{cc_major.value}.{cc_minor.value}")

    return ccs


# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
def get_compute_capability(cuda):
    """
    Extracts the highest compute capability from all available GPUs, as compute
    capabilities are downwards compatible. If no GPUs are detected, it returns
    None.
    """
    ccs = get_compute_capabilities(cuda)
    if ccs is not None:
        # TODO: handle different compute capabilities; for now, take the max
        return ccs[-1]
    return None


def evaluate_cuda_setup():
    print('')
    print('='*35 + 'BUG REPORT' + '='*35)
    print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
    print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
    print('='*80)
    return "libbitsandbytes_cuda116.dll" # $$$
binary_name = "libbitsandbytes_cpu.so"
#if not torch.cuda.is_available():
#print('No GPU detected. Loading CPU library...')
#return binary_name
cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
print(
"WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
)
return binary_name
print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
cuda_version_string = get_cuda_version(cuda, cudart_path)
if cc == '':
print(
"WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
)
return binary_name
# 7.5 is the minimum CC vor cublaslt
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
# TODO:
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed
# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
def get_binary_name():
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
bin_base_name = "libbitsandbytes_cuda"
if has_cublaslt:
return f"{bin_base_name}{cuda_version_string}.so"
else:
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
binary_name = get_binary_name()
return binary_name
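To make the selection concrete, a worked example (the GPU models are illustrative assumptions; the logic is get_binary_name() above):

# CC "8.6" (e.g. an RTX 3090) with a CUDA 11.8 runtime:
#   has_cublaslt = "8.6" in ["7.5", "8.0", "8.6"]  -> True
#   -> "libbitsandbytes_cuda118.so"
# CC "6.1" (e.g. a GTX 1080 Ti) with the same runtime:
#   has_cublaslt -> False
#   -> "libbitsandbytes_cuda118_nocublaslt.so"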

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bit, Linear8bitLt
from .modules import Embedding as StableEmbedding

View File

@@ -9,6 +9,7 @@ import torch.nn.utils.spectral_norm as SpectralNorm
 from math import sqrt
 from utils.util import checkpoint
+import torch_intermediary as ml
 def exists(val):
@@ -73,7 +74,7 @@ def initialize_weights(net_l, scale=1):
 m.weight.data *= scale # for residual block
 if m.bias is not None:
 m.bias.data.zero_()
-elif isinstance(m, nn.Linear):
+elif isinstance(m, ml.Linear):
 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
 m.weight.data *= scale
 if m.bias is not None:
@@ -108,7 +109,7 @@ def default_init_weights(module, scale=1):
 if isinstance(m, nn.Conv2d):
 kaiming_init(m, a=0, mode='fan_in', bias=0)
 m.weight.data *= scale
-elif isinstance(m, nn.Linear):
+elif isinstance(m, ml.Linear):
 kaiming_init(m, a=0, mode='fan_in', bias=0)
 m.weight.data *= scale
@@ -141,7 +142,7 @@ def linear(*args, **kwargs):
 """
 Create a linear module.
 """
-return nn.Linear(*args, **kwargs)
+return ml.Linear(*args, **kwargs)
 def avg_pool_nd(dims, *args, **kwargs):
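The hunks above and below all repeat the same substitution: nn.Linear / nn.Embedding call sites are routed through torch_intermediary (imported as ml). The shim module itself is not shown in this diff; the following is only a hypothetical sketch of what such a module could look like, assuming it simply switches between stock torch layers and the bitsandbytes classes exported earlier in this diff (Linear8bitLt, StableEmbedding) -- the real torch_intermediary may differ.

# Hypothetical torch_intermediary sketch (not part of this commit).
import torch.nn as nn

try:
    import bitsandbytes as bnb
    Linear = bnb.nn.Linear8bitLt        # int8 linear layer from bitsandbytes
    Embedding = bnb.nn.StableEmbedding  # embedding with fp32 optimizer state
except ImportError:
    Linear = nn.Linear                  # fall back to stock torch layers
    Embedding = nn.Embedding

Binding classes (rather than factory functions) keeps both construction (ml.Linear(d_in, d_out)) and the isinstance(m, ml.Linear) checks in the hunks above working.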

View File

@@ -9,6 +9,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
 from models.audio.tts.tacotron2.text import sequence_to_text
 from trainer.networks import register_model
 from utils.util import opt_get
+import torch_intermediary as ml
 def only_letters(string):
@@ -51,7 +52,7 @@ class Wav2VecWrapper(nn.Module):
 self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
 # Perform some surgery to get the model we actually want.
 self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
-self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
+self.w2v.lm_head = ml.Linear(self.w2v.config.hidden_size, vocab_size)
 self.w2v.config.vocab_size = vocab_size
 self.w2v.config.pad_token_id = 0
 self.w2v.config.ctc_loss_reduction = 'sum'

View File

@@ -5,6 +5,7 @@ import torch.nn as nn
 from trainer.networks import register_model
 from utils.util import opt_get
 from typing import Type, Any, Callable, Union, List, Optional
+import torch_intermediary as ml
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -172,7 +173,7 @@ class ResNet(nn.Module):
 self.layer4 = self._make_layer(block, 512, layers[3], stride=4,
 dilate=replace_stride_with_dilation[2])
 self.avgpool = nn.AdaptiveAvgPool1d(1)
-self.fc = nn.Linear(512 * block.expansion, num_classes)
+self.fc = ml.Linear(512 * block.expansion, num_classes)
 for m in self.modules():
 if isinstance(m, nn.Conv1d):

View File

@@ -15,13 +15,14 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 from models.arch_util import ResBlock
 from trainer.networks import register_model
 from utils.util import checkpoint
+import torch_intermediary as ml
 class Mel2Vec2FeatureProjection(nn.Module):
 def __init__(self, inner_dim, dropout):
 super().__init__()
 self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5)
-self.projection = nn.Linear(inner_dim, inner_dim)
+self.projection = ml.Linear(inner_dim, inner_dim)
 self.dropout = nn.Dropout(dropout)
 def forward(self, hidden_states):
@@ -58,10 +59,10 @@ class Wav2Vec2Attention(nn.Module):
 self.scaling = self.head_dim**-0.5
 self.is_decoder = is_decoder
-self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+self.k_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
+self.v_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
+self.q_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
+self.out_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
 def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
 return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -182,10 +183,10 @@ class Wav2Vec2FeedForward(nn.Module):
 super().__init__()
 self.intermediate_dropout = nn.Dropout(dropout)
-self.intermediate_dense = nn.Linear(hidden_size, intermediate_size)
+self.intermediate_dense = ml.Linear(hidden_size, intermediate_size)
 self.intermediate_act_fn = F.gelu
-self.output_dense = nn.Linear(intermediate_size, hidden_size)
+self.output_dense = ml.Linear(intermediate_size, hidden_size)
 self.output_dropout = nn.Dropout(dropout)
 def forward(self, hidden_states):
@@ -429,7 +430,7 @@ class Mel2Vec(nn.Module):
 k = math.sqrt(1 / module.projection.in_features)
 nn.init.uniform_(module.projection.weight, a=-k, b=k)
 nn.init.uniform_(module.projection.bias, a=-k, b=k)
-elif isinstance(module, nn.Linear):
+elif isinstance(module, ml.Linear):
 if self.disable_custom_linear_init:
 return
 module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
@@ -510,7 +511,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 self.codevectors = nn.Parameter(
 torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
 )
-self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
+self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
 # can be decayed for training
 self.temperature = 2
@@ -606,8 +607,8 @@ class ContrastiveTrainingWrapper(nn.Module):
 self.inp_length_factor = inp_length_multiplier
 # make sure that project_hid & project_q are initialized like normal linear layers
-self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
-self.project_q = nn.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
+self.project_hid = ml.Linear(inner_dim, self.quantizer.codevector_dim)
+self.project_q = ml.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
 self.reconstruction = do_reconstruction_loss
 if do_reconstruction_loss:

View File

@@ -2,6 +2,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import GPT2Config, GPT2Model
+import torch_intermediary as ml
 from models.arch_util import AttentionBlock, ResBlock
 from models.audio.tts.lucidrains_dvae import DiscreteVAE
@@ -55,8 +56,9 @@ class ConditioningAR(nn.Module):
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
-self.embeddings = nn.Embedding(num_vectors, dim)
-self.head = nn.Linear(dim, num_vectors)
+# nn.Embedding
+self.embeddings = ml.Embedding(num_vectors, dim)
+self.head = ml.Linear(dim, num_vectors)
 def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
 unused_params = []

View File

@@ -17,6 +17,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from math import sqrt
@@ -24,7 +25,7 @@ from torch.utils.checkpoint import checkpoint
 from trainer.networks import register_model
-Linear = nn.Linear
+Linear = ml.Linear
 ConvTranspose2d = nn.ConvTranspose2d

View File

@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
+import torch_intermediary as ml
 from models.arch_util import ResBlock
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@@ -22,7 +23,8 @@ def is_sequence(t):
 class MultiGroupEmbedding(nn.Module):
 def __init__(self, tokens, groups, dim):
 super().__init__()
-self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+# nn.Embedding
+self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
 def forward(self, x):
 h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@@ -158,7 +160,8 @@ class FlatDiffusion(nn.Module):
 # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
 # transformer network.
 if in_groups is None:
-self.embeddings = nn.Embedding(token_count, model_channels)
+# nn.Embedding
+self.embeddings = ml.Embedding(token_count, model_channels)
 else:
 self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
 self.latent_conditioner = nn.Sequential(

View File

@@ -2,6 +2,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 from transformers import GPT2Config, GPT2Model
+import torch_intermediary as ml
 from models.arch_util import AttentionBlock, ResBlock
 from models.audio.music.music_quantizer import MusicQuantizer
@@ -136,8 +137,9 @@ class GptMusicLower(nn.Module):
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
-self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
-self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
+# nn.Embedding
+self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
+self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
 def forward(self, mel, conditioning, return_latent=False):
 unused_params = []
@@ -238,8 +240,9 @@ class GptMusicUpper(nn.Module):
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
-self.embeddings = nn.ModuleList([nn.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
-self.heads = nn.ModuleList([nn.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
+# nn.Embedding
+self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
+self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
 def forward(self, mel, conditioning, return_latent=False):

View File

@@ -2,6 +2,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import GPT2Config, GPT2Model
+import torch_intermediary as ml
 from models.arch_util import AttentionBlock, ResBlock
 from models.audio.tts.lucidrains_dvae import DiscreteVAE
@@ -73,8 +74,9 @@ class GptMusicLower(nn.Module):
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
-self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
-self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
+# nn.Embedding
+self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
+self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
 def forward(self, mel, return_latent=False):
 unused_params = []

View File

@@ -3,6 +3,7 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from models.diffusion.nn import timestep_embedding
 from models.lucidrains.vq import VectorQuantize
@@ -21,8 +22,8 @@ class SelfClassifyingHead(nn.Module):
 use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
 self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2,
 sample_codebook_temp=init_temperature)
-self.to_output = nn.Linear(dim, out_dim)
-self.to_decoder = nn.Linear(out_dim, dim)
+self.to_output = ml.Linear(dim, out_dim)
+self.to_decoder = ml.Linear(out_dim, dim)
 def do_ar_step(self, x, used_codes):
 h = self.dec(x)
@@ -90,7 +91,7 @@ class InstrumentQuantizer(nn.Module):
 """
 super().__init__()
 self.op_dim = op_dim
-self.proj = nn.Linear(op_dim, dim)
+self.proj = ml.Linear(op_dim, dim)
 self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
 self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
 self.min_gumbel_temperature = min_temp

View File

@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from transformers import GPT2Config, GPT2Model
 from trainer.networks import register_model
@@ -17,8 +18,9 @@ class Mel2VecCodesGpt(nn.Module):
 n_inner=dim*2)
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
-self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
-self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)])
+# nn.Embedding
+self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
+self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)])
 def forward(self, codes):
 assert codes.shape[-1] == self.num_groups

View File

@@ -3,6 +3,7 @@ import functools
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from models.arch_util import zero_module
 from models.vqvae.vqvae import Quantize
@@ -75,7 +76,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 self.codevectors = nn.Parameter(
 torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
 )
-self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
+self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
 # can be decayed for training
 self.temperature = 2

View File

@@ -3,6 +3,7 @@ import functools
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from models.arch_util import zero_module
 from models.vqvae.vqvae import Quantize
@@ -87,7 +88,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 self.codevectors = nn.Parameter(
 torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
 )
-self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
+self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
 # can be decayed for training
 self.temperature = 2

View File

@@ -1,3 +1,4 @@
 import itertools
 import os
 import random
@@ -7,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
 import torchvision
+import torch_intermediary as ml
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
@@ -54,12 +56,12 @@ class ConcatAttentionBlock(TimestepBlock):
 self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
 if cond_projection:
 self.tdim = trunk_dim+cond_dim_hidden
-self.cond_project = nn.Linear(cond_dim_in, cond_dim_hidden)
+self.cond_project = ml.Linear(cond_dim_in, cond_dim_hidden)
 else:
 self.tdim = trunk_dim
 self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv)
 self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv)
-self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
+self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
 self.out.weight.data.zero_()
 def forward(self, x, cond, timestep_emb, rotary_emb):
@@ -87,7 +89,7 @@ class ConditioningEncoder(nn.Module):
 self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
 self.time_proj = time_proj
 if time_proj:
-self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
+self.time_proj = ml.Linear(time_embed_dim, embedding_dim)
 self.attn = Encoder(
 dim=embedding_dim,
 depth=attn_blocks,

View File

@@ -4,6 +4,7 @@ from time import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from models.arch_util import ResBlock
 from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
@@ -27,7 +28,8 @@ def is_sequence(t):
 class MultiGroupEmbedding(nn.Module):
 def __init__(self, tokens, groups, dim):
 super().__init__()
-self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+# nn.Embedding
+self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
 def forward(self, x):
 h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@@ -68,7 +70,7 @@ class ConcatAttentionBlock(TimestepBlock):
 self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
 self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
 self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
-self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
+self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
 self.out.weight.data.zero_()
 def forward(self, x, timestep_emb, rotary_emb):
@@ -129,7 +131,7 @@ class TransformerDiffusion(nn.Module):
 )
 prenet_heads = prenet_channels//64
-self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
+self.input_converter = ml.Linear(input_vec_dim, prenet_channels)
 self.code_converter = Encoder(
 dim=prenet_channels,
 depth=prenet_layers,
@@ -145,7 +147,7 @@
 self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
 self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
-self.intg = nn.Linear(prenet_channels*2, model_channels)
+self.intg = ml.Linear(prenet_channels*2, model_channels)
 self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
 self.out = nn.Sequential(

View File

@@ -5,6 +5,7 @@ from random import randrange
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
 RelativeQKBias
@@ -69,13 +70,14 @@ class ConditioningEncoder(nn.Module):
 super().__init__()
 attn = []
 self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
-self.resolution_embedding = nn.Embedding(num_resolutions, hidden_dim)
+# nn.Embedding
+self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim)
 self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
 for a in range(attn_blocks):
 attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
 attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
 self.attn = nn.Sequential(*attn)
-self.out = nn.Linear(hidden_dim, out_dim, bias=False)
+self.out = ml.Linear(hidden_dim, out_dim, bias=False)
 self.dim = hidden_dim
 self.do_checkpointing = do_checkpointing
@@ -131,7 +133,8 @@ class TransformerDiffusion(nn.Module):
 nn.SiLU(),
 linear(time_embed_dim, time_proj_dim),
 )
-self.resolution_embed = nn.Embedding(resolution_steps, time_proj_dim)
+# nn.Embedding
+self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim)
 self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
 self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))

View File

@@ -8,6 +8,7 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision # For debugging, not actually used.
+import torch_intermediary as ml
 from models.audio.music.gpt_music import GptMusicLower
 from models.audio.music.music_quantizer import MusicQuantizer
@@ -490,7 +491,7 @@ class UNetMusicModel(nn.Module):
 )
 if self.ar_prior:
-self.ar_input = nn.Linear(input_vec_dim, model_channels)
+self.ar_input = ml.Linear(input_vec_dim, model_channels)
 self.ar_prior_intg = Encoder(
 dim=model_channels,
 depth=4,
@@ -504,7 +505,7 @@ class UNetMusicModel(nn.Module):
 ff_mult=1,
 )
 else:
-self.input_converter = nn.Linear(input_vec_dim, model_channels)
+self.input_converter = ml.Linear(input_vec_dim, model_channels)
 self.code_converter = Encoder(
 dim=model_channels,
 depth=4,
@@ -521,7 +522,8 @@ class UNetMusicModel(nn.Module):
 self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1)
 if self.num_classes is not None:
-self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+# nn.Embedding
+self.label_emb = ml.Embedding(num_classes, time_embed_dim)
 self.use_raw_y_as_embedding = use_raw_y_as_embedding
 assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.

View File

@@ -3,6 +3,7 @@ from random import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
 from models.lucidrains.x_transformers import Encoder
@@ -36,9 +37,12 @@ class CtcCodeGenerator(nn.Module):
 self.ctc_codes = ctc_codes
 pred_codes = (max_pad+1)*(max_repeat+1)
-self.position_embedding = nn.Embedding(max_length, model_dim)
-self.codes_embedding = nn.Embedding(ctc_codes, model_dim)
-self.recursive_embedding = nn.Embedding(pred_codes, model_dim)
+# nn.Embedding
+self.position_embedding = ml.Embedding(max_length, model_dim)
+# nn.Embedding
+self.codes_embedding = ml.Embedding(ctc_codes, model_dim)
+# nn.Embedding
+self.recursive_embedding = ml.Embedding(pred_codes, model_dim)
 self.mask_embedding = nn.Parameter(torch.randn(model_dim))
 self.encoder = Encoder(
 dim=model_dim,
@@ -50,8 +54,8 @@
 ff_glu=True,
 rotary_pos_emb=True,
 )
-self.pred_head = nn.Linear(model_dim, pred_codes)
-self.confidence_head = nn.Linear(model_dim, 1)
+self.pred_head = ml.Linear(model_dim, pred_codes)
+self.confidence_head = ml.Linear(model_dim, 1)
 def inference(self, codes, pads, repeats):
 position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))

View File

@@ -5,6 +5,8 @@ from functools import partial
 import torch
 import torch.nn as nn
+import torch_intermediary as ml
 from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
 DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
 exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
@@ -16,7 +18,7 @@ class TimeIntegrationBlock(nn.Module):
 super().__init__()
 self.emb_layers = nn.Sequential(
 nn.SiLU(),
-nn.Linear(
+ml.Linear(
 time_emb_dim,
 2 * dim
 ),

View File

@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch_intermediary as ml
 from models.diffusion.nn import normalization, conv_nd, zero_module
@@ -138,7 +139,7 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
 def __init__(self, classes, distribute_zero_label=True, **kwargs):
 super().__init__()
 self.enc = AudioMiniEncoder(**kwargs)
-self.head = nn.Linear(self.enc.dim, classes)
+self.head = ml.Linear(self.enc.dim, classes)
 self.num_classes = classes
 self.distribute_zero_label = distribute_zero_label
@@ -183,7 +184,7 @@ class QueryProvidedAttentionBlock(nn.Module):
 ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
 self.num_heads = channels // num_head_channels
 self.norm = normalization(channels)
-self.q = nn.Linear(channels, channels)
+self.q = ml.Linear(channels, channels)
 self.qnorm = nn.LayerNorm(channels)
 self.kv = conv_nd(1, channels, channels*2, 1)
 if use_new_attention_order:

View File

@@ -3,6 +3,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -44,7 +45,7 @@ class RandomLatentConverter(nn.Module):
 def __init__(self, channels):
 super().__init__()
 self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
-nn.Linear(channels, channels))
+ml.Linear(channels, channels))
 self.channels = channels
 def forward(self, ref):

View File

@@ -3,12 +3,13 @@ from librosa.filters import mel as librosa_mel_fn
 from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
 from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
 from models.audio.tts.tacotron2.stft import STFT
+import torch_intermediary as ml
 class LinearNorm(torch.nn.Module):
 def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
 super(LinearNorm, self).__init__()
-self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+self.linear_layer = torch.ml.Linear(in_dim, out_dim, bias=bias)
 torch.nn.init.xavier_uniform_(
 self.linear_layer.weight,
View File

@@ -8,6 +8,7 @@ from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
 from models.audio.tts.tacotron2.hparams import create_hparams
 from trainer.networks import register_model
 from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
+import torch_intermediary as ml
 class LocationLayer(nn.Module):
@@ -463,7 +464,8 @@ class Tacotron2(nn.Module):
 self.fp16_run = hparams.fp16_run
 self.n_mel_channels = hparams.n_mel_channels
 self.n_frames_per_step = hparams.n_frames_per_step
-self.embedding = nn.Embedding(
+# nn.Embedding
+self.embedding = ml.Embedding(
 hparams.n_symbols, hparams.symbols_embedding_dim)
 std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
 val = sqrt(3.0) * std # uniform bounds for std

View File

@@ -13,6 +13,7 @@ from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
 from trainer.networks import register_model
 from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
 from utils.util import checkpoint
+import torch_intermediary as ml
@@ -185,7 +186,8 @@ class WaveTacotron2(nn.Module):
 self.fp16_run = hparams.fp16_run
 self.n_mel_channels = hparams.n_mel_channels
 self.n_frames_per_step = hparams.n_frames_per_step
-self.embedding = nn.Embedding(
+# nn.Embedding
+self.embedding = ml.Embedding(
 hparams.n_symbols, hparams.symbols_embedding_dim)
 std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
 val = sqrt(3.0) * std # uniform bounds for std

View File

@@ -25,6 +25,7 @@ import random
 from time import time
 import torch
 import torch.nn as nn
+import torch_intermediary as ml
 from tqdm import tqdm
@@ -35,7 +36,8 @@ def null_position_embeddings(range, dim):
 class LearnedPositionEmbeddings(nn.Module):
 def __init__(self, seq_len, model_dim, init=.02, relative=False):
 super().__init__()
-self.emb = nn.Embedding(seq_len, model_dim)
+# nn.Embedding
+self.emb = ml.Embedding(seq_len, model_dim)
 # Initializing this way is standard for GPT-2
 self.emb.weight.data.normal_(mean=0.0, std=init)
 self.relative = relative

View File

@@ -7,6 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlo
 from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
 from trainer.networks import register_model
 from utils.util import checkpoint
+import torch_intermediary as ml
 def is_latent(t):
@@ -19,7 +20,8 @@ def is_sequence(t):
 class MultiGroupEmbedding(nn.Module):
 def __init__(self, tokens, groups, dim):
 super().__init__()
-self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+# nn.Embedding
+self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
 def forward(self, x):
 h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@@ -100,15 +102,17 @@ class TransformerDiffusionTTS(nn.Module):
 ff_glu=True,
 rotary_pos_emb=True,
 )
-self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
-self.type_embedding = nn.Embedding(types, model_channels)
+self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels)
+# nn.Embedding
+self.type_embedding = ml.Embedding(types, model_channels)
 # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
 # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
 # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
 # transformer network.
 if in_groups is None:
-self.embeddings = nn.Embedding(token_count, model_channels)
+# nn.Embedding
+self.embeddings = ml.Embedding(token_count, model_channels)
 else:
 self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
 self.latent_conditioner = nn.Sequential(
@@ -140,7 +144,7 @@
 self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
 self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
-self.intg = nn.Linear(model_channels*2, model_channels)
+self.intg = ml.Linear(model_channels*2, model_channels)
 self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
 self.out = nn.Sequential(

View File

@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
@@ -19,7 +20,8 @@ def is_sequence(t):
 class MultiGroupEmbedding(nn.Module):
 def __init__(self, tokens, groups, dim):
 super().__init__()
-self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+# nn.Embedding
+self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
 def forward(self, x):
 h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@@ -40,7 +42,7 @@ class DietAttentionBlock(TimestepBlock):
 def __init__(self, in_dim, dim, heads, dropout):
 super().__init__()
 self.rms_scale_norm = RMSScaleShiftNorm(in_dim)
-self.proj = nn.Linear(in_dim, dim)
+self.proj = ml.Linear(in_dim, dim)
 self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout)
 self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
@@ -105,15 +107,17 @@ class TransformerDiffusionTTS(nn.Module):
 ff_glu=True,
 rotary_pos_emb=True,
 )
-self.clvp_encoder = nn.Linear(clvp_in_dim, prenet_channels)
-self.type_embedding = nn.Embedding(types, prenet_channels)
+self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels)
+# nn.Embedding
+self.type_embedding = ml.Embedding(types, prenet_channels)
 # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
 # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
 # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
 # transformer network.
 if in_groups is None:
-self.embeddings = nn.Embedding(token_count, prenet_channels)
+# nn.Embedding
+self.embeddings = ml.Embedding(token_count, prenet_channels)
 else:
 self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels)
 self.latent_conditioner = nn.Sequential(
@@ -144,8 +148,8 @@
 self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
 self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
-self.cond_intg = nn.Linear(prenet_channels*4, model_channels)
-self.intg = nn.Linear(prenet_channels*2, model_channels)
+self.cond_intg = ml.Linear(prenet_channels*4, model_channels)
+self.intg = ml.Linear(prenet_channels*2, model_channels)
 self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])

View File

@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
+import torch_intermediary as ml
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
@@ -247,14 +248,16 @@ class DiffusionTts(nn.Module):
 )
 embedding_dim = model_channels * 8
-self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim)
+# nn.Embedding
+self.code_embedding = ml.Embedding(num_tokens+1, embedding_dim)
 self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
 attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
 self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1)
 self.enable_unaligned_inputs = enabled_unaligned_inputs
 if enabled_unaligned_inputs:
-self.unaligned_embedder = nn.Embedding(num_unaligned_tokens, embedding_dim)
+# nn.Embedding
+self.unaligned_embedder = ml.Embedding(num_unaligned_tokens, embedding_dim)
 self.unaligned_encoder = CheckpointedXTransformerEncoder(
 max_seq_len=-1,
 use_pos_emb=False,

View File

@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
 from x_transformers import Encoder
+import torch_intermediary as ml
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
@@ -206,7 +207,8 @@ class DiffusionTts(nn.Module):
 # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
 # transformer network.
 self.code_converter = nn.Sequential(
-nn.Embedding(in_tokens, conditioning_dim),
+# nn.Embedding
+ml.Embedding(in_tokens, conditioning_dim),
 CheckpointedXTransformerEncoder(
 needs_permute=False,
 max_seq_len=-1,

View File

@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
+import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
@@ -193,7 +194,9 @@ class DiffusionTtsFlat(nn.Module):
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
-self.code_embedding = nn.Embedding(in_tokens, model_channels)
+# nn.Embedding
+self.code_embedding = ml.Embedding(in_tokens, model_channels)
self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
@@ -12,6 +13,7 @@ from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_e
from trainer.networks import register_model
from utils.util import opt_get
+import torch_intermediary as ml
class ResBlock(nn.Module):
"""
@@ -279,9 +281,11 @@ class UnifiedVoice(nn.Module):
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.average_conditioning_embeddings = average_conditioning_embeddings
-self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
+# nn.Embedding
+self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
if use_mel_codes_as_input:
-self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+# nn.Embedding
+self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
@@ -294,8 +298,8 @@ class UnifiedVoice(nn.Module):
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
-self.text_head = nn.Linear(model_dim, self.number_text_tokens)
+self.text_head = ml.Linear(model_dim, self.number_text_tokens)
-self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
+self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding]

View File

@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch_intermediary as ml
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
@@ -271,15 +273,17 @@ class UnifiedVoice(nn.Module):
self.model_dim = model_dim
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
-self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
-self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+# nn.Embedding
+self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
+# nn.Embedding
+self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim)
-self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
+self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
-self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
+self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
-self.aligned_head = nn.Linear(model_dim, number_aligned_text_codes)
+self.aligned_head = ml.Linear(model_dim, number_aligned_text_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding, self.mel_embedding]

View File

@@ -11,6 +11,7 @@ from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
from trainer.networks import register_model
from utils.util import opt_get
+import torch_intermediary as ml
class ResBlock(nn.Module):
@@ -255,15 +256,17 @@ class UnifiedVoice(nn.Module):
self.model_dim = model_dim
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
-self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
-self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+# nn.Embedding
+self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
+# nn.Embedding
+self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim)
-self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
+self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
-self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
+self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
-self.alignment_head = nn.Linear(model_dim, 256)
+self.alignment_head = ml.Linear(model_dim, 256)
if only_alignment_head:
for p in self.parameters():

View File

@@ -8,6 +8,7 @@ from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.injectors.spec_augment import spec_augment
from trainer.networks import register_model
from utils.util import opt_get
+import torch_intermediary as ml
def exists(val):
@@ -36,7 +37,7 @@ class VoiceCLIP(nn.Module):
self.encoder = AudioMiniEncoder(80, encoder_output)
if pretrained_encoder_dict_path is not None:
self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
-self.to_latent = nn.Linear(encoder_output, dim_latent, bias=False)
+self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.mel_compression_ratio = mel_compression_ratio

View File

@@ -7,6 +7,7 @@ from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model
+import torch_intermediary as ml
class CheckpointedLayer(nn.Module):
@@ -56,7 +57,8 @@ class Wav2VecMatcher(nn.Module):
WAV2VEC_CHANNELS = 1024
self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
-self.text_embedding = nn.Embedding(num_text_tokens, model_dim)
+# nn.Embedding
+self.text_embedding = ml.Embedding(num_text_tokens, model_dim)
self.encoder = CheckpointedXTransformer(
max_seq_len=-1,
use_pos_emb=False,
@@ -73,8 +75,8 @@ class Wav2VecMatcher(nn.Module):
)
self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim))
self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim))
-self.w2v_query_encoder = nn.Linear(WAV2VEC_CHANNELS, model_dim)
+self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
-self.w2v_value_encoder = nn.Linear(WAV2VEC_CHANNELS, model_dim)
+self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
self.decoder = CheckpointedXTransformer(
max_seq_len=-1, # Should be unused
use_pos_emb=False,

View File

@@ -10,6 +10,7 @@
import torch
import torch.nn as nn
+import torch_intermediary as ml
from trainer.networks import register_model
@@ -98,7 +99,7 @@ class ResNet(nn.Module):
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-self.fc = nn.Linear(256 * block.expansion, num_classes)
+self.fc = ml.Linear(256 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
"""make resnet layers(by layer i didnt mean this 'layer' was the

View File

@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
import torchvision
+import torch_intermediary as ml
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -194,5 +195,5 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
def register_resnet50(opt_net, opt):
model = resnet50(pretrained=opt_net['pretrained'])
if opt_net['custom_head_logits']:
-model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
+model.fc = ml.Linear(512 * 4, opt_net['custom_head_logits'])
return model

View File

@@ -11,6 +11,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch_intermediary as ml
from trainer.networks import register_model
@@ -101,7 +102,7 @@ class ResNet(nn.Module):
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-self.fc = nn.Linear(256 * block.expansion, num_classes)
+self.fc = ml.Linear(256 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
"""make resnet layers(by layer i didnt mean this 'layer' was the

View File

@@ -11,6 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
from models.vqvae.scaled_weight_conv import ScaledWeightConv
from trainer.networks import register_model
from utils.util import checkpoint
+import torch_intermediary as ml
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
@@ -213,7 +214,7 @@ class ResNet(nn.Module):
self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-self.fc = nn.Linear(512 * block.expansion, num_classes)
+self.fc = ml.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, ScaledWeightConv):

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
from trainer.networks import register_model
from utils.util import opt_get
+import torch_intermediary as ml
class WideKernelVgg(nn.Module):
def __init__(self, nf=64, num_classes=2):
@@ -49,9 +49,9 @@ class WideKernelVgg(nn.Module):
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Flatten(),
-nn.Linear(nf * 8 * 4 * 2, 100),
+ml.Linear(nf * 8 * 4 * 2, 100),
nn.ReLU(),
-nn.Linear(100, num_classes)
+ml.Linear(100, num_classes)
)
# These normalization constants should be derived experimentally.

View File

@@ -10,6 +10,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
+import torch_intermediary as ml
def exists(val):
@@ -58,7 +59,8 @@ class CollapsingTransformer(nn.Module):
class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
-self.emb = nn.Embedding(*args, **kwargs)
+# nn.Embedding
+self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@@ -98,9 +100,10 @@ class CLVP(nn.Module):
self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True)
self.mask_conditioning_percentage = mask_conditioning_percentage
-self.text_emb = nn.Embedding(num_text_tokens, model_dim)
+# nn.Embedding
+self.text_emb = ml.Embedding(num_text_tokens, model_dim)
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
-self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+self.to_text_latent = ml.Linear(latent_dim, latent_dim, bias=False)
self.distributed_collect = distributed_collect
if mel_codes is None:
@@ -108,7 +111,7 @@ class CLVP(nn.Module):
else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
-self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self):
return {

View File

@@ -9,6 +9,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
+import torch_intermediary as ml
def exists(val):
@@ -178,7 +179,8 @@ class CollapsingTransformer(nn.Module):
class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
-self.emb = nn.Embedding(*args, **kwargs)
+# nn.Embedding
+self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@@ -203,8 +205,8 @@ class ContrastiveAudio(nn.Module):
self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent)
-self.to_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+self.to_latent = ml.Linear(latent_dim, latent_dim, bias=False)
-self.to_latent2 = nn.Linear(latent_dim, latent_dim, bias=False)
+self.to_latent2 = ml.Linear(latent_dim, latent_dim, bias=False)
self.to_latent2.weight.data = self.to_latent.weight.data
self.to_latent2.weight.DO_NOT_TRAIN = True

View File

@@ -10,6 +10,7 @@ from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
+import torch_intermediary as ml
def exists(val):
@@ -58,7 +59,8 @@ class CollapsingTransformer(nn.Module):
class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
-self.emb = nn.Embedding(*args, **kwargs)
+# nn.Embedding
+self.emb = ml.Embedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
@@ -86,14 +88,14 @@ class CVVP(nn.Module):
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
-self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+self.to_conditioning_latent = ml.Linear(latent_dim, latent_dim, bias=False)
if mel_codes is None:
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
-self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self):
return {

View File

@@ -7,6 +7,7 @@ from torch import einsum
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
+import torch_intermediary as ml
def exists(val):
@@ -45,17 +46,20 @@ class MelTextCLIP(nn.Module):
mel_compression=256,
):
super().__init__()
-self.text_emb = nn.Embedding(num_text_tokens, dim_text)
-self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
+# nn.Embedding
+self.text_emb = ml.Embedding(num_text_tokens, dim_text)
+# nn.Embedding
+self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
heads=text_heads, rotary_emb=False)
-self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
+self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1)
-self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
+# nn.Embedding
+self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
-self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.text_mask_percentage = text_mask_percentage

View File

@@ -7,6 +7,7 @@ from models.audio.tts.unified_voice2 import ConditioningEncoder
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
+import torch_intermediary as ml
def exists(val):
@@ -45,7 +46,7 @@ class VoiceCondCLIP(nn.Module):
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
-self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.voice_mask_percentage = voice_mask_percentage

View File

@@ -11,6 +11,7 @@ from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
+import torch_intermediary as ml
def exists(val):
@@ -53,11 +54,13 @@ class VoiceCLIP(nn.Module):
distributed_collect=False,
):
super().__init__()
-self.text_emb = nn.Embedding(num_text_tokens, dim_text)
+# nn.Embedding
+self.text_emb = ml.Embedding(num_text_tokens, dim_text)
-self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
+self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)
-self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
+# nn.Embedding
+self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech)
-self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)
if use_xformers:
self.text_transformer = CheckpointedXTransformerEncoder(
@@ -105,8 +108,10 @@ class VoiceCLIP(nn.Module):
self.min_mel_size = min_mel_size
self.distributed_collect = distributed_collect
if not use_xformers:
-self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
+# nn.Embedding
+self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
-self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
+# nn.Embedding
+self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
def embed_text(self, text):
text_mask = torch.ones_like(text.float()).bool()

View File

@@ -6,6 +6,7 @@ import math
import torch as th
import torch.nn as nn
+import torch_intermediary as ml
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
@@ -36,7 +37,7 @@ def linear(*args, **kwargs):
"""
Create a linear module.
"""
-return nn.Linear(*args, **kwargs)
+return ml.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
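Because every diffusion module builds its projections through this linear() helper, patching its return value is enough to switch the whole diffusion stack at once; no call sites change. A small usage sketch under that assumption (the dimensions and the time_embed name below are illustrative, not taken from this commit):

import torch.nn as nn
from models.diffusion.nn import linear  # the helper patched above

# Both layers are now created through torch_intermediary, so they can become
# bitsandbytes 8-bit linears without touching this call site.
time_embed = nn.Sequential(
    linear(256, 1024),
    nn.SiLU(),
    linear(1024, 1024),
)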

View File

@@ -6,6 +6,7 @@ from models.arch_util import ConvGnLelu, default_init_weights, make_layer
from models.diffusion.nn import timestep_embedding
from trainer.networks import register_model
from utils.util import checkpoint
+import torch_intermediary as ml
# Conditionally uses torch's checkpoint functionality if it is enabled in the opt file.
@@ -28,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True)
self.emb_layers = nn.Sequential(
nn.SiLU(),
-nn.Linear(
+ml.Linear(
mid_channels*4,
mid_channels,
),
@@ -143,9 +144,9 @@ class RRDBNet(nn.Module):
# Guided diffusion uses a time embedding.
time_embed_dim = mid_channels * 4
self.time_embed = nn.Sequential(
-nn.Linear(mid_channels, time_embed_dim),
+ml.Linear(mid_channels, time_embed_dim),
nn.SiLU(),
-nn.Linear(time_embed_dim, time_embed_dim),
+ml.Linear(time_embed_dim, time_embed_dim),
)
self.body = make_layer(

View File

@@ -20,6 +20,7 @@ from models.diffusion.nn import (
)
from trainer.networks import register_model
from utils.util import checkpoint
+import torch_intermediary as ml
class AttentionPool2d(nn.Module):
@@ -515,7 +516,8 @@ class UNetModel(nn.Module):
)
if self.num_classes is not None:
-self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+# nn.Embedding
+self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
@@ -867,16 +869,16 @@ class EncoderUNetModel(nn.Module):
)
elif pool == "spatial":
self.out = nn.Sequential(
-nn.Linear(self._feature_size, 2048),
+ml.Linear(self._feature_size, 2048),
nn.ReLU(),
-nn.Linear(2048, self.out_channels),
+ml.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
self.out = nn.Sequential(
-nn.Linear(self._feature_size, 2048),
+ml.Linear(self._feature_size, 2048),
normalization(2048),
nn.SiLU(),
-nn.Linear(2048, self.out_channels),
+ml.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")

View File

@@ -26,6 +26,7 @@ from models.diffusion.nn import (
)
from trainer.networks import register_model
from utils.util import checkpoint
+import torch_intermediary as ml
class AttentionPool2d(nn.Module):
@@ -476,7 +477,8 @@ class UNetModel(nn.Module):
)
if self.num_classes is not None:
-self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+# nn.Embedding
+self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[
@@ -736,7 +738,7 @@ class ResNetEncoder(nn.Module):
dilate=replace_stride_with_dilation[2])
f=512
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-self.fc = nn.Linear(f * block.expansion, output_dim)
+self.fc = ml.Linear(f * block.expansion, output_dim)
for m in self.modules():
if isinstance(m, nn.Conv2d):

View File

@@ -5,6 +5,7 @@ import torch.nn.functional as F
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
+import torch_intermediary as ml
class Discriminator_VGG_128(nn.Module):
@@ -46,8 +47,8 @@ class Discriminator_VGG_128(nn.Module):
input_img_factor = input_img_factor // 2
final_nf = nf * 16
-self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
+self.linear1 = ml.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
-self.linear2 = nn.Linear(100, 1)
+self.linear2 = ml.Linear(100, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
@@ -129,8 +130,8 @@ class Discriminator_VGG_128_GN(nn.Module):
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
+self.linear1 = ml.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
-self.linear2 = nn.Linear(100, 1)
+self.linear2 = ml.Linear(100, 1)
def compute_body(self, x):
fea = self.lrelu(self.conv0_0(x))
@@ -219,8 +220,8 @@ class DiscriminatorVGG448GN(nn.Module):
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
final_nf = nf * 8
-self.linear1 = nn.Linear(int(final_nf * 7 * 7), 100)
+self.linear1 = ml.Linear(int(final_nf * 7 * 7), 100)
-self.linear2 = nn.Linear(100, 1)
+self.linear2 = ml.Linear(100, 1)
# Assign all new heads to the new param group.2
for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]:

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
+import torch_intermediary as ml
def initialize_weights(net_l, scale=1):
@@ -14,7 +15,7 @@ def initialize_weights(net_l, scale=1):
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
-elif isinstance(m, nn.Linear):
+elif isinstance(m, ml.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
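The change to initialize_weights is the knock-on effect of the swap: once layers are constructed through the intermediary, type checks in weight-init helpers have to test against ml.Linear, otherwise the branch is skipped whenever ml.Linear is not the plain nn.Linear class. A minimal sketch of that idea, assuming the same ml import used throughout this commit:

import torch.nn.init as init
import torch_intermediary as ml  # assumed to expose the Linear class used above

def init_linears(net, scale=1.0):
    # Mirrors the initialize_weights branch above: only layers built through
    # the intermediary are re-initialized here.
    for m in net.modules():
        if isinstance(m, ml.Linear):
            init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            m.weight.data *= scale
            if m.bias is not None:
                m.bias.data.zero_()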

View File

@@ -28,6 +28,7 @@ except:
APEX_AVAILABLE = False
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
+import torch_intermediary as ml
num_cores = multiprocessing.cpu_count()
@@ -351,7 +352,7 @@ class RGBBlock(nn.Module):
def __init__(self, latent_dim, input_channel, upsample, rgba=False):
super().__init__()
self.input_channel = input_channel
-self.to_style = nn.Linear(latent_dim, input_channel)
+self.to_style = ml.Linear(latent_dim, input_channel)
out_filters = 3 if not rgba else 4
self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
@@ -489,16 +490,16 @@ class GeneratorBlockWithStructure(nn.Module):
# Uses stylegan1 style blocks for injecting structural latent.
self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
-self.to_noise0 = nn.Linear(1, filters)
+self.to_noise0 = ml.Linear(1, filters)
self.noise0 = equal_lr(NoiseInjection(filters))
self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)
-self.to_style1 = nn.Linear(latent_dim, filters)
+self.to_style1 = ml.Linear(latent_dim, filters)
-self.to_noise1 = nn.Linear(1, filters)
+self.to_noise1 = ml.Linear(1, filters)
self.conv1 = Conv2DMod(filters, filters, 3)
-self.to_style2 = nn.Linear(latent_dim, filters)
+self.to_style2 = ml.Linear(latent_dim, filters)
-self.to_noise2 = nn.Linear(1, filters)
+self.to_noise2 = ml.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
@@ -540,12 +541,12 @@ class GeneratorBlock(nn.Module):
self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
input_channels = input_channels * 2
-self.to_style1 = nn.Linear(latent_dim, input_channels)
+self.to_style1 = ml.Linear(latent_dim, input_channels)
-self.to_noise1 = nn.Linear(1, filters)
+self.to_noise1 = ml.Linear(1, filters)
self.conv1 = Conv2DMod(input_channels, filters, 3)
-self.to_style2 = nn.Linear(latent_dim, filters)
+self.to_style2 = ml.Linear(latent_dim, filters)
-self.to_noise2 = nn.Linear(1, filters)
+self.to_noise2 = ml.Linear(1, filters)
self.conv2 = Conv2DMod(filters, filters, 3)
self.activation = leaky_relu()
@@ -724,7 +725,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):
def _init_weights(self):
for m in self.modules():
-if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
+if type(m) in {nn.Conv2d, ml.Linear} and hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
for block in self.gen.blocks:
@@ -804,7 +805,7 @@ class StyleGan2Discriminator(nn.Module):
self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
self.flatten = Flatten()
-self.to_logit = nn.Linear(latent_dim, 1)
+self.to_logit = ml.Linear(latent_dim, 1)
self._init_weights()
@@ -836,7 +837,7 @@ class StyleGan2Discriminator(nn.Module):
def _init_weights(self):
for m in self.modules():
-if type(m) in {nn.Conv2d, nn.Linear}:
+if type(m) in {nn.Conv2d, ml.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

View File

@@ -12,6 +12,7 @@ from torch import nn
from data.images.byol_attachment import RandomApply
from trainer.networks import register_model, create_model
from utils.util import checkpoint, opt_get
+import torch_intermediary as ml
def default(val, def_val):
@@ -78,10 +79,10 @@ class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size=4096):
super().__init__()
self.net = nn.Sequential(
-nn.Linear(dim, hidden_size),
+ml.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
-nn.Linear(hidden_size, projection_size)
+ml.Linear(hidden_size, projection_size)
)
def forward(self, x):
@@ -103,10 +104,10 @@ class StructuralMLP(nn.Module):
nn.BatchNorm2d(c),
nn.ReLU(inplace=True),
nn.Flatten(),
-nn.Linear(flattened_dim, hidden_size),
+ml.Linear(flattened_dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
-nn.Linear(hidden_size, projection_size)
+ml.Linear(hidden_size, projection_size)
)
def forward(self, x):

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
+import torch_intermediary as ml
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
@@ -108,8 +109,8 @@ class FixupResNet(nn.Module):
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
self.bias2 = nn.Parameter(torch.zeros(1))
reduced_img_sz = int(input_img_size / 32)
-self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
+self.fc1 = ml.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
-self.fc2 = nn.Linear(100, num_classes)
+self.fc2 = ml.Linear(100, num_classes)
for m in self.modules():
if isinstance(m, FixupBasicBlock):
@@ -124,7 +125,7 @@ class FixupResNet(nn.Module):
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
'''
-elif isinstance(m, nn.Linear):
+elif isinstance(m, ml.Linear):
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)'''

View File

@@ -5,6 +5,7 @@ import torch.nn.functional as F
from models.arch_util import ResBlock
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
+import torch_intermediary as ml
class VitLatent(nn.Module):
@@ -31,10 +32,10 @@ class VitLatent(nn.Module):
do_checkpointing=True
)
-self.mlp = nn.Sequential(nn.Linear(hidden_dim, hidden_dim*2),
+self.mlp = nn.Sequential(ml.Linear(hidden_dim, hidden_dim*2),
nn.BatchNorm1d(hidden_dim*2),
nn.ReLU(inplace=True),
-nn.Linear(hidden_dim*2, hidden_dim))
+ml.Linear(hidden_dim*2, hidden_dim))
def provide_ema(self, ema):
self.ema = ema

View File

@@ -7,6 +7,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from rotary_embedding_torch import apply_rotary_emb
+import torch_intermediary as ml
# helpers
@@ -47,9 +48,9 @@ class Attention(nn.Module):
self.stable = stable
self.causal = causal
-self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
-nn.Linear(inner_dim, dim),
+ml.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
@@ -102,10 +103,10 @@ class SparseConvCausalAttention(nn.Module):
self.stable = stable
-self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
-nn.Linear(inner_dim, dim),
+ml.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
@@ -222,10 +223,10 @@ class SparseAxialCausalAttention(nn.Module):
self.stable = stable
-self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
-nn.Linear(inner_dim, dim),
+ml.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

View File

@@ -11,6 +11,7 @@ from models.lucidrains.dalle.attention import Attention, SparseAttention, Sparse
from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock
+import torch_intermediary as ml
# helpers
@@ -78,10 +79,10 @@ class FeedForward(nn.Module):
def __init__(self, dim, dropout = 0., mult = 4.):
super().__init__()
self.net = nn.Sequential(
-nn.Linear(dim, dim * mult * 2),
+ml.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
-nn.Linear(dim * mult, dim)
+ml.Linear(dim * mult, dim)
)
def forward(self, x):

View File

@@ -21,6 +21,7 @@ try:
APEX_AVAILABLE = True
except:
APEX_AVAILABLE = False
+import torch_intermediary as ml
# helpers
@@ -356,10 +357,10 @@ class FeedForward(nn.Module):
activation = default(activation, nn.GELU)
self.glu = glu
-self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
+self.w1 = ml.Linear(dim, dim * mult * (2 if glu else 1))
self.act = activation()
self.dropout = nn.Dropout(dropout)
-self.w2 = nn.Linear(dim * mult, dim)
+self.w2 = ml.Linear(dim * mult, dim)
def forward(self, x, **kwargs):
if not self.glu:
@@ -401,10 +402,10 @@ class Attention(nn.Module):
self.global_heads = heads - local_heads
self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
-self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
+self.to_q = ml.Linear(dim, inner_dim, bias = qkv_bias)
-self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
+self.to_k = ml.Linear(dim, inner_dim, bias = qkv_bias)
-self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
+self.to_v = ml.Linear(dim, inner_dim, bias = qkv_bias)
-self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias)
+self.to_out = ml.Linear(inner_dim, dim, bias = attn_out_bias)
self.dropout = nn.Dropout(dropout)
def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs):
@@ -458,7 +459,8 @@ class CrossAttention(Attention):
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
-self.emb = nn.Embedding(max_seq_len, dim)
+# nn.Embedding
+self.emb = ml.Embedding(max_seq_len, dim)
def forward(self, x):
t = torch.arange(x.shape[1], device=x.device)
@@ -619,7 +621,8 @@ class PerformerLM(nn.Module):
local_attn_heads = cast_tuple(local_attn_heads)
self.max_seq_len = max_seq_len
-self.token_emb = nn.Embedding(num_tokens, dim)
+# nn.Embedding
+self.token_emb = ml.Embedding(num_tokens, dim)
if rotary_position_emb:
self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
@@ -636,7 +639,7 @@ class PerformerLM(nn.Module):
self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens)
self.norm = nn.LayerNorm(dim)
-self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None
+self.to_out = ml.Linear(dim, num_tokens) if not tie_embed else None
def check_redraw_projections(self):
self.performer.check_redraw_projections()

View File

@@ -8,6 +8,7 @@ from torch.cuda.amp import autocast
from einops import rearrange, repeat
from contextlib import contextmanager
+import torch_intermediary as ml
def par(t, nm):
@@ -355,9 +356,9 @@ class VectorQuantize(nn.Module):
codebook_dim = default(codebook_dim, dim)
requires_projection = codebook_dim != dim
-self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \
+self.project_in = ml.Linear(dim, codebook_dim) if requires_projection \
else nn.Identity()
-self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \
+self.project_out = ml.Linear(codebook_dim, dim) if requires_projection \
else nn.Identity()
self.eps = eps

View File

@ -11,6 +11,7 @@ from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import torch_intermediary as ml
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64
@ -125,7 +126,8 @@ class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len): def __init__(self, dim, max_seq_len):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim ** -0.5
self.emb = nn.Embedding(max_seq_len, dim) # nn.Embedding
self.emb = ml.Embedding(max_seq_len, dim)
def forward(self, x): def forward(self, x):
n = torch.arange(x.shape[1], device=x.device) n = torch.arange(x.shape[1], device=x.device)
@ -154,7 +156,8 @@ class RelativePositionBias(nn.Module):
self.causal = causal self.causal = causal
self.num_buckets = num_buckets self.num_buckets = num_buckets
self.max_distance = max_distance self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads) # nn.Embedding
self.relative_attention_bias = ml.Embedding(num_buckets, heads)
@staticmethod @staticmethod
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
@@ -360,7 +363,7 @@ class RMSScaleShiftNorm(nn.Module):
            self.cdim = 1
            self.pdim = -1
        else:
-            self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)
+            self.scale_shift_process = ml.Linear(embed_dim, dim * 2, bias=bias)
            self.cdim = -1
            self.pdim = 1
@@ -447,7 +450,7 @@ class GLU(nn.Module):
    def __init__(self, dim_in, dim_out, activation):
        super().__init__()
        self.act = activation
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        self.proj = ml.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
@@ -472,7 +475,7 @@ class FeedForward(nn.Module):
        activation = ReluSquared() if relu_squared else nn.GELU()

        project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            ml.Linear(dim, inner_dim),
            activation
        ) if not glu else GLU(dim, inner_dim, activation)
@@ -480,7 +483,7 @@ class FeedForward(nn.Module):
            project_in,
            nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
            nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            ml.Linear(inner_dim, dim_out)
        )

        # init last linear layer to 0
@@ -535,16 +538,16 @@ class Attention(nn.Module):
            qk_dim = int(collab_compression * qk_dim)
            self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))

-        self.to_q = nn.Linear(dim, qk_dim, bias=False)
-        self.to_k = nn.Linear(dim, qk_dim, bias=False)
-        self.to_v = nn.Linear(dim, v_dim, bias=False)
+        self.to_q = ml.Linear(dim, qk_dim, bias=False)
+        self.to_k = ml.Linear(dim, qk_dim, bias=False)
+        self.to_v = ml.Linear(dim, v_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

        # add GLU gating for aggregated values, from alphafold2
        self.to_v_gate = None
        if gate_values:
-            self.to_v_gate = nn.Linear(dim, v_dim)
+            self.to_v_gate = ml.Linear(dim, v_dim)
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 1)
@@ -581,7 +584,7 @@ class Attention(nn.Module):
        # attention on attention
        self.attn_on_attn = on_attn
        out_dim = default(out_dim, dim)
-        self.to_out = nn.Sequential(nn.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, out_dim)
+        self.to_out = nn.Sequential(ml.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, out_dim)

        self.rel_pos_bias = rel_pos_bias
        if rel_pos_bias:
@@ -1077,7 +1080,7 @@ class ViTransformerWrapper(nn.Module):
        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, dim)
+        self.patch_to_embedding = ml.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
@@ -1135,18 +1138,19 @@ class TransformerWrapper(nn.Module):
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

-        self.token_emb = nn.Embedding(num_tokens, emb_dim)
+        # nn.Embedding
+        self.token_emb = ml.Embedding(num_tokens, emb_dim)
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
            use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

-        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+        self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

-        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+        self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
@@ -1233,12 +1237,12 @@ class ContinuousTransformerWrapper(nn.Module):
            use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)
-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+        self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()

    def forward(
        self,
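These embedding swaps keep torch's call signature, so nothing downstream changes. A minimal sketch (shapes illustrative) of what ml.Embedding looks like to callers: with this commit's defaults it resolves to plain torch.nn.Embedding, and it only becomes bitsandbytes' StableEmbedding if the torch_intermediary flags added later in this commit are switched on.

import torch
import torch_intermediary as ml

# Same constructor and forward signature as torch.nn.Embedding either way.
emb = ml.Embedding(num_embeddings=512, embedding_dim=64)
tokens = torch.randint(0, 512, (2, 128))   # (batch, seq) of token ids
vectors = emb(tokens)                      # (batch, seq, 64)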


@@ -4,13 +4,15 @@ import torch.nn.functional as F
from torch import einsum

from utils.weight_scheduler import LinearDecayWeightScheduler
+import torch_intermediary as ml


class GumbelQuantizer(nn.Module):
    def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False):
        super().__init__()
        self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
-        self.codebook = nn.Embedding(num_tokens, codebook_dim)
+        # nn.Embedding
+        self.codebook = ml.Embedding(num_tokens, codebook_dim)
        self.straight_through = straight_through
        self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
        self.step = 0


@@ -4,6 +4,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat

from models.arch_util import l2norm, sample_vectors, default, ema_inplace
+import torch_intermediary as ml


def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
@@ -184,8 +185,8 @@ class VectorQuantize(nn.Module):
        codebook_dim = default(codebook_dim, dim)
        requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+        self.project_in = ml.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
+        self.project_out = ml.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        self.eps = eps


@@ -0,0 +1,41 @@
"""
from bitsandbytes.nn import Linear8bitLt as Linear
from bitsandbytes.nn import StableEmbedding as Embedding
from bitsandbytes.optim.adam import Adam8bit as Adam
from bitsandbytes.optim.adamw import AdamW8bit as AdamW
"""
"""
from torch.nn import Linear
from torch.nn import Embedding
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
"""

OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = False
OVERRIDE_ADAM = True
OVERRIDE_ADAMW = True
USE_STABLE_EMBEDDING = True

if OVERRIDE_LINEAR:
    from bitsandbytes.nn import Linear8bitLt as Linear
else:
    from torch.nn import Linear

if OVERRIDE_EMBEDDING:
    if USE_STABLE_EMBEDDING:
        from bitsandbytes.nn import StableEmbedding as Embedding
    else:
        from bitsandbytes.nn import Embedding as Embedding
else:
    from torch.nn import Embedding

if OVERRIDE_ADAM:
    from bitsandbytes.optim.adam import Adam8bit as Adam
else:
    from torch.optim.adam import Adam

if OVERRIDE_ADAMW:
    from bitsandbytes.optim.adamw import AdamW8bit as AdamW
else:
    from torch.optim.adamw import AdamW
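This shim gives the rest of the repo a single import point for swappable layers and optimizers. A minimal usage sketch (assuming a CUDA-capable bitsandbytes install; the tiny model and hyperparameters are illustrative) of what callers see with the defaults above, where ml.Linear is plain torch.nn.Linear but ml.Adam/ml.AdamW are the 8-bit bitsandbytes optimizers:

import torch
import torch.nn as nn
import torch_intermediary as ml

# With OVERRIDE_LINEAR = False these are ordinary torch.nn.Linear layers.
model = nn.Sequential(ml.Linear(16, 32), nn.ReLU(), ml.Linear(32, 4)).cuda()

# With OVERRIDE_ADAMW = True this is bitsandbytes' AdamW8bit, which keeps its
# optimizer state in 8 bits and expects parameters on a CUDA device.
opt = ml.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2, betas=(0.9, 0.999))

x = torch.randn(8, 16, device='cuda')
loss = model(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()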


@@ -21,6 +21,7 @@ import torchvision.utils as utils
from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
from utils.util import opt_get, denormalize
+import torch_intermediary as ml

logger = logging.getLogger('base')
@@ -337,7 +338,7 @@ class ExtensibleTrainer(BaseModel):
            for net in self.networks.values():
                for mod in net.modules():
                    fan_in = -1
-                    if isinstance(mod, nn.Linear):
+                    if isinstance(mod, ml.Linear):
                        fan_in = mod.weight.data.shape[1]
                    elif isinstance(mod, nn.Conv1d):
                        fan_in = mod.weight.data.shape[0]
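The fan-in scan only changes which class it keys on. A standalone sketch of the same heuristic (the helper name is illustrative, not from the trainer):

import torch.nn as nn
import torch_intermediary as ml

def fan_in_of(mod):
    # Keyed on ml.Linear so it follows whatever torch_intermediary resolves
    # Linear to (plain torch.nn.Linear with this commit's defaults).
    if isinstance(mod, ml.Linear):
        return mod.weight.data.shape[1]   # in_features
    elif isinstance(mod, nn.Conv1d):
        return mod.weight.data.shape[0]   # mirrors the trainer's Conv1d branch
    return -1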


@@ -1,3 +1,4 @@
import logging
from collections import OrderedDict
@@ -6,6 +7,7 @@ import torch.nn as nn
import trainer.networks as networks
import trainer.lr_scheduler as lr_scheduler
from .base_model import BaseModel
+import torch_intermediary as ml

logger = logging.getLogger('base')
@@ -40,7 +42,8 @@ class FeatureModel(BaseModel):
            else:
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
-        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+        # torch.optim.Adam
+        self.optimizer_G = ml.Adam(optim_params, lr=train_opt['lr_G'],
                                    weight_decay=wd_G,
                                    betas=(train_opt['beta1_G'], train_opt['beta2_G']))
        self.optimizers.append(self.optimizer_G)


@@ -3,10 +3,10 @@ from collections import Counter
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import _LRScheduler
+import torch_intermediary as ml
from utils.util import opt_get
@@ -136,7 +136,8 @@ class CosineAnnealingLR_Restart(_LRScheduler):

if __name__ == "__main__":
-    optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
+    #torch.optim.Adam
+    optimizer = ml.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
                         betas=(0.9, 0.99))

    ##############################
    # MultiStepLR_Restart


@@ -12,6 +12,7 @@ from utils.util import recursively_detach, opt_get, clip_grad_norm

logger = logging.getLogger('base')
+import torch_intermediary as ml


# Defines the expected API for a single training step
class ConfigurableStep(Module):
@@ -82,7 +83,8 @@ class ConfigurableStep(Module):
        import torch.nn as nn
        norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
                        nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
-        emb_modules = (nn.Embedding, nn.EmbeddingBag)
+        # nn.Embedding
+        emb_modules = (ml.Embedding, nn.EmbeddingBag)
        param_names_notweights = set()
        all_param_names = set()
        param_map = {}
@@ -123,7 +125,8 @@ class ConfigurableStep(Module):
                { 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
                { 'params': params_notweights, 'weight_decay': 0 }
            ]
-            opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
+            # torch.optim.AdamW
+            opt = ml.AdamW(groups, lr=opt_config['lr'],
                            weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                            betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            opt._group_names = [params_names_weights, params_names_notweights]
@@ -141,14 +144,16 @@ class ConfigurableStep(Module):
            # The torch ZeRO implementation does not seem to support parameter groups, so do not shard the non-weighted
            # parameters and just use a normal AdamW implementation. In a large network, these weights will normally
            # be a tiny fraction of the total weights.
-            opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
+            # torch.optim.AdamW
+            opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
                            betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            opt_unweighted._config = opt_config
            opt_unweighted._config['network'] = net_name
            opt_unweighted._group_names = []
            self.optimizers.append(opt_unweighted)

-            opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
+            # torch.optim.AdamW
+            opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=ml.AdamW, lr=opt_config['lr'],
                            weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                            betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            opt.param_groups[0]['initial_lr'] = opt_config['lr']
@@ -162,7 +167,8 @@ class ConfigurableStep(Module):
            opt._group_names = sorted(list(all_param_names))
        elif self.step_opt['optimizer'] == 'lamb':
            from trainer.optimizers.lamb import Lamb
-            opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
+            # torch.optim.AdamW
+            opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
                            betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            opt_unweighted._config = opt_config
            opt_unweighted._config['network'] = net_name
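The grouped-optimizer pattern above carries over to the 8-bit optimizer unchanged, since the diff passes the same parameter-group dicts to ml.AdamW that it previously passed to torch.optim.AdamW. A minimal sketch of that construction (function and argument names illustrative, defaults assumed):

import torch_intermediary as ml

def build_step_optimizer(weights, not_weights, lr=1e-4, weight_decay=1e-2, betas=(0.9, 0.999)):
    groups = [
        {'params': weights, 'weight_decay': weight_decay},  # decayed parameters
        {'params': not_weights, 'weight_decay': 0},         # norms/embeddings: no decay
    ]
    # ml.AdamW is torch.optim.AdamW or bitsandbytes' AdamW8bit, depending on OVERRIDE_ADAMW.
    return ml.AdamW(groups, lr=lr, weight_decay=weight_decay, betas=betas)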


@@ -45,4 +45,7 @@ rotary-embedding-torch
axial_positional_embedding
g-mlp-pytorch
x-clip
x_transformers==1.0.4
+
+# bitsandbytes
+bitsandbytes==0.35.0


@@ -67,6 +67,8 @@ setuptools.setup(
        "g-mlp-pytorch",
        "x-clip",
        "x_transformers==1.0.4",
+        "bitsandbytes==0.35.0",
    ],
    classifiers=[
        "Programming Language :: Python :: 3",