forked from mrq/DL-Art-School
Merge pull request 'bitsandbytes' (#2) from bitsandbytes into master
Reviewed-on: mrq/DL-Art-School#2
This commit is contained in:
commit 918473807f
bitsandbytes_windows/cextension.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import ctypes as ct
from pathlib import Path
from warnings import warn

from .cuda_setup.main import evaluate_cuda_setup


class CUDALibrary_Singleton(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        binary_name = evaluate_cuda_setup()
        package_dir = Path(__file__).parent
        binary_path = package_dir / binary_name

        if not binary_path.exists():
            print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
            legacy_binary_name = "libbitsandbytes.so"
            print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
            binary_path = package_dir / legacy_binary_name
            if not binary_path.exists():
                print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
                print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
                raise Exception('CUDA SETUP: Setup Failed!')
            # self.lib = ct.cdll.LoadLibrary(binary_path)
            self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
        else:
            print(f"CUDA SETUP: Loading binary {binary_path}...")
            # self.lib = ct.cdll.LoadLibrary(binary_path)
            self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance


lib = CUDALibrary_Singleton.get_instance().lib
try:
    lib.cadam32bit_g32
    lib.get_context.restype = ct.c_void_p
    lib.get_cusparse.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn(
        "The installed version of bitsandbytes was compiled without GPU support. "
        "8-bit optimizers and GPU quantization are unavailable."
    )
    COMPILED_WITH_CUDA = False
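The file above resolves and loads the shared library once, at import time; downstream bitsandbytes code only consults `COMPILED_WITH_CUDA` and calls into `lib`. A minimal usage sketch, not part of this PR, assuming these Windows files have been copied over an installed `bitsandbytes` package (which is how the port is meant to be used):

# Hedged usage sketch: only exercises names defined in cextension.py above.
import ctypes as ct

from bitsandbytes.cextension import COMPILED_WITH_CUDA, lib

if COMPILED_WITH_CUDA:
    # `lib` is a ctypes CDLL handle; exported C symbols appear as attributes.
    # cextension.py already set lib.get_context.restype = ct.c_void_p above.
    ctx_ptr = lib.get_context()
    print("CUDA backend loaded, context pointer:", ctx_ptr)
else:
    print("CPU-only build: 8-bit optimizers and GPU quantization are unavailable.")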
bitsandbytes_windows/cuda_setup/main.py (new file, 166 lines)
@@ -0,0 +1,166 @@
"""
extract factors the build is dependent on:
[X] compute capability
    [ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
    - CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication)
    - CuBLAS-LT: full-build 8-bit optimizer
    - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)

evaluation:
    - if paths faulty, return meaningful error
    - else:
        - determine CUDA version
        - determine capabilities
        - based on that set the default path
"""

import ctypes

from .paths import determine_cuda_runtime_lib_path


def check_cuda_result(cuda, result_val):
    # 3. Check for CUDA errors
    if result_val != 0:
        error_str = ctypes.c_char_p()
        cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
        print(f"CUDA exception! Error code: {error_str.value.decode()}")


def get_cuda_version(cuda, cudart_path):
    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
    try:
        cudart = ctypes.CDLL(cudart_path)
    except OSError:
        # TODO: shouldn't we error or at least warn here?
        print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
        return None

    version = ctypes.c_int()
    check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
    version = int(version.value)
    major = version//1000
    minor = (version-(major*1000))//10

    if major < 11:
        print('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!!')

    return f'{major}{minor}'


def get_cuda_lib_handle():
    # 1. find libcuda.so library (GPU driver) (/usr/lib)
    try:
        cuda = ctypes.CDLL("libcuda.so")
    except OSError:
        # TODO: shouldn't we error or at least warn here?
        print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
        return None
    check_cuda_result(cuda, cuda.cuInit(0))

    return cuda


def get_compute_capabilities(cuda):
    """
    1. find libcuda.so library (GPU driver) (/usr/lib)
       init_device -> init variables -> call function by reference
    2. call extern C function to determine CC
       (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
    3. Check for CUDA errors
       https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
    # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
    """


    nGpus = ctypes.c_int()
    cc_major = ctypes.c_int()
    cc_minor = ctypes.c_int()

    device = ctypes.c_int()

    check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
    ccs = []
    for i in range(nGpus.value):
        check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
        ref_major = ctypes.byref(cc_major)
        ref_minor = ctypes.byref(cc_minor)
        # 2. call extern C function to determine CC
        check_cuda_result(
            cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
        )
        ccs.append(f"{cc_major.value}.{cc_minor.value}")

    return ccs


# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
def get_compute_capability(cuda):
    """
    Extracts the highest compute capability from all available GPUs, as compute
    capabilities are downwards compatible. If no GPUs are detected, it returns
    None.
    """
    ccs = get_compute_capabilities(cuda)
    if ccs is not None:
        # TODO: handle different compute capabilities; for now, take the max
        return ccs[-1]
    return None


def evaluate_cuda_setup():
    print('')
    print('='*35 + 'BUG REPORT' + '='*35)
    print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
    print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
    print('='*80)
    return "libbitsandbytes_cuda116.dll" # $$$

    binary_name = "libbitsandbytes_cpu.so"
    #if not torch.cuda.is_available():
        #print('No GPU detected. Loading CPU library...')
        #return binary_name

    cudart_path = determine_cuda_runtime_lib_path()
    if cudart_path is None:
        print(
            "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
        )
        return binary_name

    print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
    cuda = get_cuda_lib_handle()
    cc = get_compute_capability(cuda)
    print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
    cuda_version_string = get_cuda_version(cuda, cudart_path)


    if cc == '':
        print(
            "WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library..."
        )
        return binary_name

    # 7.5 is the minimum CC for cublaslt
    has_cublaslt = cc in ["7.5", "8.0", "8.6"]

    # TODO:
    # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
    # (2) Multiple CUDA versions installed

    # we use ls -l instead of nvcc to determine the cuda version
    # since most installations will have the libcudart.so installed, but not the compiler
    print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')

    def get_binary_name():
        "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
        bin_base_name = "libbitsandbytes_cuda"
        if has_cublaslt:
            return f"{bin_base_name}{cuda_version_string}.so"
        else:
            return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"

    binary_name = get_binary_name()

    return binary_name
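Note that on this branch `evaluate_cuda_setup()` returns the hard-coded Windows binary name before any of the probing below it runs, so the libcudart/compute-capability logic is effectively dead code here. A small sanity-check sketch, not part of this PR, again assuming the files have been copied into an installed `bitsandbytes` package; it mirrors how `CUDALibrary_Singleton.initialize()` resolves and loads that name:

# Hedged sanity-check sketch: resolve and load the DLL that the patched
# evaluate_cuda_setup() names, the same way cextension.py does above.
import ctypes as ct
from pathlib import Path

import bitsandbytes
from bitsandbytes.cuda_setup.main import evaluate_cuda_setup

binary_name = evaluate_cuda_setup()  # "libbitsandbytes_cuda116.dll" on this branch
binary_path = Path(bitsandbytes.__file__).parent / binary_name
assert binary_path.exists(), f"expected DLL at {binary_path}"
lib = ct.cdll.LoadLibrary(str(binary_path))
print("loaded", binary_path)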
bitsandbytes_windows/libbitsandbytes_cpu.dll (new binary file, not shown)
bitsandbytes_windows/libbitsandbytes_cuda116.dll (new binary file, not shown)
bitsandbytes_windows/nn/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bit, Linear8bitLt
from .modules import Embedding as StableEmbedding
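The remaining hunks all follow one pattern: construction sites of `torch.nn.Linear` and `torch.nn.Embedding` are swapped for `ml.Linear` and `ml.Embedding`, and an `import torch_intermediary as ml` is added to each file. The `torch_intermediary` module itself is not part of this diff. Judging only from how it is called (identical constructor arguments, `.weight` and `.bias` accessed directly afterwards), it is presumably a thin switch between the stock torch layers and the 8-bit variants exported by `bitsandbytes.nn` above. A hypothetical sketch, with the flag name and structure assumed rather than taken from the repository:

# torch_intermediary -- hypothetical sketch only; this module is NOT shown in
# the diff. Inferred from call sites: ml.Linear / ml.Embedding must accept the
# same arguments as nn.Linear / nn.Embedding and expose .weight / .bias.
import torch.nn as nn

try:
    import bitsandbytes as bnb
    _HAS_BNB = True
except ImportError:
    _HAS_BNB = False

USE_8BIT = False  # assumed switch; the real module may gate this differently


def Linear(in_features, out_features, bias=True, **kwargs):
    if _HAS_BNB and USE_8BIT:
        # Linear8bitLt is one of the classes re-exported by bitsandbytes.nn above.
        return bnb.nn.Linear8bitLt(in_features, out_features, bias=bias, **kwargs)
    return nn.Linear(in_features, out_features, bias=bias, **kwargs)


def Embedding(num_embeddings, embedding_dim, **kwargs):
    if _HAS_BNB and USE_8BIT:
        # StableEmbedding keeps nn.Embedding's interface, including .weight.
        return bnb.nn.StableEmbedding(num_embeddings, embedding_dim, **kwargs)
    return nn.Embedding(num_embeddings, embedding_dim, **kwargs)

Because both branches return modules exposing the standard `.weight` and `.bias` tensors, call sites such as `self.emb.weight.data.normal_(mean=0.0, std=init)` in the hunks below continue to work unchanged.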
@@ -9,6 +9,7 @@ import torch.nn.utils.spectral_norm as SpectralNorm
 from math import sqrt
 
 from utils.util import checkpoint
+import torch_intermediary as ml
 
 
 def exists(val):
@@ -73,7 +74,7 @@ def initialize_weights(net_l, scale=1):
 m.weight.data *= scale # for residual block
 if m.bias is not None:
 m.bias.data.zero_()
-elif isinstance(m, nn.Linear):
+elif isinstance(m, ml.Linear):
 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
 m.weight.data *= scale
 if m.bias is not None:
@@ -108,7 +109,7 @@ def default_init_weights(module, scale=1):
 if isinstance(m, nn.Conv2d):
 kaiming_init(m, a=0, mode='fan_in', bias=0)
 m.weight.data *= scale
-elif isinstance(m, nn.Linear):
+elif isinstance(m, ml.Linear):
 kaiming_init(m, a=0, mode='fan_in', bias=0)
 m.weight.data *= scale
 
@@ -141,7 +142,7 @@ def linear(*args, **kwargs):
 """
 Create a linear module.
 """
-return nn.Linear(*args, **kwargs)
+return ml.Linear(*args, **kwargs)
 
 
 def avg_pool_nd(dims, *args, **kwargs):
@@ -9,6 +9,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
 from models.audio.tts.tacotron2.text import sequence_to_text
 from trainer.networks import register_model
 from utils.util import opt_get
+import torch_intermediary as ml
 
 
 def only_letters(string):
@@ -51,7 +52,7 @@ class Wav2VecWrapper(nn.Module):
 self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
 # Perform some surgery to get the model we actually want.
 self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
-self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
+self.w2v.lm_head = ml.Linear(self.w2v.config.hidden_size, vocab_size)
 self.w2v.config.vocab_size = vocab_size
 self.w2v.config.pad_token_id = 0
 self.w2v.config.ctc_loss_reduction = 'sum'
@@ -5,6 +5,7 @@ import torch.nn as nn
 from trainer.networks import register_model
 from utils.util import opt_get
 from typing import Type, Any, Callable, Union, List, Optional
+import torch_intermediary as ml
 
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -172,7 +173,7 @@ class ResNet(nn.Module):
 self.layer4 = self._make_layer(block, 512, layers[3], stride=4,
 dilate=replace_stride_with_dilation[2])
 self.avgpool = nn.AdaptiveAvgPool1d(1)
-self.fc = nn.Linear(512 * block.expansion, num_classes)
+self.fc = ml.Linear(512 * block.expansion, num_classes)
 
 for m in self.modules():
 if isinstance(m, nn.Conv1d):
@@ -15,13 +15,14 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 from models.arch_util import ResBlock
 from trainer.networks import register_model
 from utils.util import checkpoint
+import torch_intermediary as ml
 
 
 class Mel2Vec2FeatureProjection(nn.Module):
 def __init__(self, inner_dim, dropout):
 super().__init__()
 self.layer_norm = nn.LayerNorm(inner_dim, eps=1e-5)
-self.projection = nn.Linear(inner_dim, inner_dim)
+self.projection = ml.Linear(inner_dim, inner_dim)
 self.dropout = nn.Dropout(dropout)
 
 def forward(self, hidden_states):
@@ -58,10 +59,10 @@ class Wav2Vec2Attention(nn.Module):
 self.scaling = self.head_dim**-0.5
 self.is_decoder = is_decoder
 
-self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+self.k_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
+self.v_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
+self.q_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
+self.out_proj = ml.Linear(embed_dim, embed_dim, bias=bias)
 
 def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
 return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -182,10 +183,10 @@ class Wav2Vec2FeedForward(nn.Module):
 super().__init__()
 self.intermediate_dropout = nn.Dropout(dropout)
 
-self.intermediate_dense = nn.Linear(hidden_size, intermediate_size)
+self.intermediate_dense = ml.Linear(hidden_size, intermediate_size)
 self.intermediate_act_fn = F.gelu
 
-self.output_dense = nn.Linear(intermediate_size, hidden_size)
+self.output_dense = ml.Linear(intermediate_size, hidden_size)
 self.output_dropout = nn.Dropout(dropout)
 
 def forward(self, hidden_states):
@@ -429,7 +430,7 @@ class Mel2Vec(nn.Module):
 k = math.sqrt(1 / module.projection.in_features)
 nn.init.uniform_(module.projection.weight, a=-k, b=k)
 nn.init.uniform_(module.projection.bias, a=-k, b=k)
-elif isinstance(module, nn.Linear):
+elif isinstance(module, ml.Linear):
 if self.disable_custom_linear_init:
 return
 module.weight.data.normal_(mean=0.0, std=self.linear_init_scale)
@@ -510,7 +511,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 self.codevectors = nn.Parameter(
 torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
 )
-self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
+self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
 
 # can be decayed for training
 self.temperature = 2
@@ -606,8 +607,8 @@ class ContrastiveTrainingWrapper(nn.Module):
 self.inp_length_factor = inp_length_multiplier
 
 # make sure that project_hid & project_q are initialized like normal linear layers
-self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
-self.project_q = nn.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
+self.project_hid = ml.Linear(inner_dim, self.quantizer.codevector_dim)
+self.project_q = ml.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
 
 self.reconstruction = do_reconstruction_loss
 if do_reconstruction_loss:
@@ -2,6 +2,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import GPT2Config, GPT2Model
+import torch_intermediary as ml
 
 from models.arch_util import AttentionBlock, ResBlock
 from models.audio.tts.lucidrains_dvae import DiscreteVAE
@@ -55,8 +56,9 @@ class ConditioningAR(nn.Module):
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
 
-self.embeddings = nn.Embedding(num_vectors, dim)
-self.head = nn.Linear(dim, num_vectors)
+# nn.Embedding
+self.embeddings = ml.Embedding(num_vectors, dim)
+self.head = ml.Linear(dim, num_vectors)
 
 def forward(self, cheater_codes, conditioning, code_lengths=None, return_latent=False):
 unused_params = []
@@ -17,6 +17,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 
 from math import sqrt
 
@@ -24,7 +25,7 @@ from torch.utils.checkpoint import checkpoint
 
 from trainer.networks import register_model
 
-Linear = nn.Linear
+Linear = ml.Linear
 ConvTranspose2d = nn.ConvTranspose2d
 
 
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
+import torch_intermediary as ml
 
 from models.arch_util import ResBlock
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@@ -22,7 +23,8 @@ def is_sequence(t):
 class MultiGroupEmbedding(nn.Module):
 def __init__(self, tokens, groups, dim):
 super().__init__()
-self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+# nn.Embedding
+self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
 
 def forward(self, x):
 h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@@ -158,7 +160,8 @@ class FlatDiffusion(nn.Module):
 # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
 # transformer network.
 if in_groups is None:
-self.embeddings = nn.Embedding(token_count, model_channels)
+# nn.Embedding
+self.embeddings = ml.Embedding(token_count, model_channels)
 else:
 self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
 self.latent_conditioner = nn.Sequential(
@@ -2,6 +2,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 from transformers import GPT2Config, GPT2Model
+import torch_intermediary as ml
 
 from models.arch_util import AttentionBlock, ResBlock
 from models.audio.music.music_quantizer import MusicQuantizer
@@ -136,8 +137,9 @@ class GptMusicLower(nn.Module):
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
 
-self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
-self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
+# nn.Embedding
+self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
+self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
 
 def forward(self, mel, conditioning, return_latent=False):
 unused_params = []
@@ -238,8 +240,9 @@ class GptMusicUpper(nn.Module):
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
 
-self.embeddings = nn.ModuleList([nn.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
-self.heads = nn.ModuleList([nn.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
+# nn.Embedding
+self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
+self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
 
 
 def forward(self, mel, conditioning, return_latent=False):
@@ -2,6 +2,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import GPT2Config, GPT2Model
+import torch_intermediary as ml
 
 from models.arch_util import AttentionBlock, ResBlock
 from models.audio.tts.lucidrains_dvae import DiscreteVAE
@@ -73,8 +74,9 @@ class GptMusicLower(nn.Module):
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
 
-self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
-self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
+# nn.Embedding
+self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
+self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
 
 def forward(self, mel, return_latent=False):
 unused_params = []
@@ -3,6 +3,7 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 
 from models.diffusion.nn import timestep_embedding
 from models.lucidrains.vq import VectorQuantize
@@ -21,8 +22,8 @@ class SelfClassifyingHead(nn.Module):
 use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
 self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2,
 sample_codebook_temp=init_temperature)
-self.to_output = nn.Linear(dim, out_dim)
-self.to_decoder = nn.Linear(out_dim, dim)
+self.to_output = ml.Linear(dim, out_dim)
+self.to_decoder = ml.Linear(out_dim, dim)
 
 def do_ar_step(self, x, used_codes):
 h = self.dec(x)
@@ -90,7 +91,7 @@ class InstrumentQuantizer(nn.Module):
 """
 super().__init__()
 self.op_dim = op_dim
-self.proj = nn.Linear(op_dim, dim)
+self.proj = ml.Linear(op_dim, dim)
 self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
 self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
 self.min_gumbel_temperature = min_temp
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 from transformers import GPT2Config, GPT2Model
 
 from trainer.networks import register_model
@@ -17,8 +18,9 @@ class Mel2VecCodesGpt(nn.Module):
 n_inner=dim*2)
 self.gpt = GPT2Model(self.config)
 del self.gpt.wte # Unused, we'll do our own embeddings.
-self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
-self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)])
+# nn.Embedding
+self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
+self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)])
 
 def forward(self, codes):
 assert codes.shape[-1] == self.num_groups
@@ -3,6 +3,7 @@ import functools
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 
 from models.arch_util import zero_module
 from models.vqvae.vqvae import Quantize
@@ -75,7 +76,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 self.codevectors = nn.Parameter(
 torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
 )
-self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
+self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
 
 # can be decayed for training
 self.temperature = 2
@@ -3,6 +3,7 @@ import functools
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 
 from models.arch_util import zero_module
 from models.vqvae.vqvae import Quantize
@@ -87,7 +88,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 self.codevectors = nn.Parameter(
 torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
 )
-self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
+self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
 
 # can be decayed for training
 self.temperature = 2
@@ -1,3 +1,4 @@
+
 import itertools
 import os
 import random
@@ -7,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
 import torchvision
+import torch_intermediary as ml
 
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
@@ -54,12 +56,12 @@ class ConcatAttentionBlock(TimestepBlock):
 self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
 if cond_projection:
 self.tdim = trunk_dim+cond_dim_hidden
-self.cond_project = nn.Linear(cond_dim_in, cond_dim_hidden)
+self.cond_project = ml.Linear(cond_dim_in, cond_dim_hidden)
 else:
 self.tdim = trunk_dim
 self.block1 = SubBlock(self.tdim, contraction_dim, heads, dropout, use_conv)
 self.block2 = SubBlock(self.tdim+contraction_dim*2, contraction_dim, heads, dropout, use_conv)
-self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
+self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
 self.out.weight.data.zero_()
 
 def forward(self, x, cond, timestep_emb, rotary_emb):
@@ -87,7 +89,7 @@ class ConditioningEncoder(nn.Module):
 self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
 self.time_proj = time_proj
 if time_proj:
-self.time_proj = nn.Linear(time_embed_dim, embedding_dim)
+self.time_proj = ml.Linear(time_embed_dim, embedding_dim)
 self.attn = Encoder(
 dim=embedding_dim,
 depth=attn_blocks,
@@ -4,6 +4,7 @@ from time import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 
 from models.arch_util import ResBlock
 from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
@@ -27,7 +28,8 @@ def is_sequence(t):
 class MultiGroupEmbedding(nn.Module):
 def __init__(self, tokens, groups, dim):
 super().__init__()
-self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+# nn.Embedding
+self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
 
 def forward(self, x):
 h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@@ -68,7 +70,7 @@ class ConcatAttentionBlock(TimestepBlock):
 self.prenorm = RMSScaleShiftNorm(trunk_dim, embed_dim=time_embed_dim, bias=False)
 self.block1 = SubBlock(trunk_dim, contraction_dim, heads, dropout)
 self.block2 = SubBlock(trunk_dim+contraction_dim*2, contraction_dim, heads, dropout)
-self.out = nn.Linear(contraction_dim*4, trunk_dim, bias=False)
+self.out = ml.Linear(contraction_dim*4, trunk_dim, bias=False)
 self.out.weight.data.zero_()
 
 def forward(self, x, timestep_emb, rotary_emb):
@@ -129,7 +131,7 @@ class TransformerDiffusion(nn.Module):
 )
 
 prenet_heads = prenet_channels//64
-self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
+self.input_converter = ml.Linear(input_vec_dim, prenet_channels)
 self.code_converter = Encoder(
 dim=prenet_channels,
 depth=prenet_layers,
@@ -145,7 +147,7 @@ class TransformerDiffusion(nn.Module):
 
 self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
 self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
-self.intg = nn.Linear(prenet_channels*2, model_channels)
+self.intg = ml.Linear(prenet_channels*2, model_channels)
 self.layers = TimestepRotaryEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim, num_heads, dropout) for _ in range(num_layers)])
 
 self.out = nn.Sequential(
@@ -5,6 +5,7 @@ from random import randrange
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 
 from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
 RelativeQKBias
@@ -69,13 +70,14 @@ class ConditioningEncoder(nn.Module):
 super().__init__()
 attn = []
 self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
-self.resolution_embedding = nn.Embedding(num_resolutions, hidden_dim)
+# nn.Embedding
+self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim)
 self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
 for a in range(attn_blocks):
 attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
 attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
 self.attn = nn.Sequential(*attn)
-self.out = nn.Linear(hidden_dim, out_dim, bias=False)
+self.out = ml.Linear(hidden_dim, out_dim, bias=False)
 self.dim = hidden_dim
 self.do_checkpointing = do_checkpointing
 
@@ -131,7 +133,8 @@ class TransformerDiffusion(nn.Module):
 nn.SiLU(),
 linear(time_embed_dim, time_proj_dim),
 )
-self.resolution_embed = nn.Embedding(resolution_steps, time_proj_dim)
+# nn.Embedding
+self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim)
 self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
 self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
 
@@ -8,6 +8,7 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision # For debugging, not actually used.
+import torch_intermediary as ml
 
 from models.audio.music.gpt_music import GptMusicLower
 from models.audio.music.music_quantizer import MusicQuantizer
@@ -490,7 +491,7 @@ class UNetMusicModel(nn.Module):
 )
 
 if self.ar_prior:
-self.ar_input = nn.Linear(input_vec_dim, model_channels)
+self.ar_input = ml.Linear(input_vec_dim, model_channels)
 self.ar_prior_intg = Encoder(
 dim=model_channels,
 depth=4,
@@ -504,7 +505,7 @@ class UNetMusicModel(nn.Module):
 ff_mult=1,
 )
 else:
-self.input_converter = nn.Linear(input_vec_dim, model_channels)
+self.input_converter = ml.Linear(input_vec_dim, model_channels)
 self.code_converter = Encoder(
 dim=model_channels,
 depth=4,
@@ -521,7 +522,8 @@ class UNetMusicModel(nn.Module):
 self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1)
 
 if self.num_classes is not None:
-self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+# nn.Embedding
+self.label_emb = ml.Embedding(num_classes, time_embed_dim)
 self.use_raw_y_as_embedding = use_raw_y_as_embedding
 assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
 
@@ -3,6 +3,7 @@ from random import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 
 from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
 from models.lucidrains.x_transformers import Encoder
@@ -36,9 +37,12 @@ class CtcCodeGenerator(nn.Module):
 self.ctc_codes = ctc_codes
 pred_codes = (max_pad+1)*(max_repeat+1)
 
-self.position_embedding = nn.Embedding(max_length, model_dim)
-self.codes_embedding = nn.Embedding(ctc_codes, model_dim)
-self.recursive_embedding = nn.Embedding(pred_codes, model_dim)
+# nn.Embedding
+self.position_embedding = ml.Embedding(max_length, model_dim)
+# nn.Embedding
+self.codes_embedding = ml.Embedding(ctc_codes, model_dim)
+# nn.Embedding
+self.recursive_embedding = ml.Embedding(pred_codes, model_dim)
 self.mask_embedding = nn.Parameter(torch.randn(model_dim))
 self.encoder = Encoder(
 dim=model_dim,
@@ -50,8 +54,8 @@ class CtcCodeGenerator(nn.Module):
 ff_glu=True,
 rotary_pos_emb=True,
 )
-self.pred_head = nn.Linear(model_dim, pred_codes)
-self.confidence_head = nn.Linear(model_dim, 1)
+self.pred_head = ml.Linear(model_dim, pred_codes)
+self.confidence_head = ml.Linear(model_dim, 1)
 
 def inference(self, codes, pads, repeats):
 position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))
@@ -5,6 +5,8 @@ from functools import partial
 
 import torch
 import torch.nn as nn
+import torch_intermediary as ml
+
 from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
 DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
 exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
@@ -16,7 +18,7 @@ class TimeIntegrationBlock(nn.Module):
 super().__init__()
 self.emb_layers = nn.Sequential(
 nn.SiLU(),
-nn.Linear(
+ml.Linear(
 time_emb_dim,
 2 * dim
 ),
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch_intermediary as ml
 
 
 from models.diffusion.nn import normalization, conv_nd, zero_module
@@ -138,7 +139,7 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
 def __init__(self, classes, distribute_zero_label=True, **kwargs):
 super().__init__()
 self.enc = AudioMiniEncoder(**kwargs)
-self.head = nn.Linear(self.enc.dim, classes)
+self.head = ml.Linear(self.enc.dim, classes)
 self.num_classes = classes
 self.distribute_zero_label = distribute_zero_label
 
@@ -183,7 +184,7 @@ class QueryProvidedAttentionBlock(nn.Module):
 ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
 self.num_heads = channels // num_head_channels
 self.norm = normalization(channels)
-self.q = nn.Linear(channels, channels)
+self.q = ml.Linear(channels, channels)
 self.qnorm = nn.LayerNorm(channels)
 self.kv = conv_nd(1, channels, channels*2, 1)
 if use_new_attention_order:
@@ -3,6 +3,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml
 
 from trainer.networks import register_model
 from utils.util import opt_get
@@ -44,7 +45,7 @@ class RandomLatentConverter(nn.Module):
 def __init__(self, channels):
 super().__init__()
 self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
-nn.Linear(channels, channels))
+ml.Linear(channels, channels))
 self.channels = channels
 
 def forward(self, ref):
@@ -3,12 +3,13 @@ from librosa.filters import mel as librosa_mel_fn
 from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
 from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
 from models.audio.tts.tacotron2.stft import STFT
+import torch_intermediary as ml
 
 
 class LinearNorm(torch.nn.Module):
 def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
 super(LinearNorm, self).__init__()
-self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+self.linear_layer = ml.Linear(in_dim, out_dim, bias=bias)
 
 torch.nn.init.xavier_uniform_(
 self.linear_layer.weight,
@@ -8,6 +8,7 @@ from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
 from models.audio.tts.tacotron2.hparams import create_hparams
 from trainer.networks import register_model
 from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
+import torch_intermediary as ml
 
 
 class LocationLayer(nn.Module):
@@ -463,7 +464,8 @@ class Tacotron2(nn.Module):
 self.fp16_run = hparams.fp16_run
 self.n_mel_channels = hparams.n_mel_channels
 self.n_frames_per_step = hparams.n_frames_per_step
-self.embedding = nn.Embedding(
+# nn.Embedding
+self.embedding = ml.Embedding(
 hparams.n_symbols, hparams.symbols_embedding_dim)
 std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
 val = sqrt(3.0) * std # uniform bounds for std
@@ -13,6 +13,7 @@ from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
 from trainer.networks import register_model
 from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
 from utils.util import checkpoint
+import torch_intermediary as ml
 
 
 
@@ -185,7 +186,8 @@ class WaveTacotron2(nn.Module):
 self.fp16_run = hparams.fp16_run
 self.n_mel_channels = hparams.n_mel_channels
 self.n_frames_per_step = hparams.n_frames_per_step
-self.embedding = nn.Embedding(
+# nn.Embedding
+self.embedding = ml.Embedding(
 hparams.n_symbols, hparams.symbols_embedding_dim)
 std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
 val = sqrt(3.0) * std # uniform bounds for std
@@ -25,6 +25,7 @@ import random
 from time import time
 import torch
 import torch.nn as nn
+import torch_intermediary as ml
 from tqdm import tqdm
 
 
@@ -35,7 +36,8 @@ def null_position_embeddings(range, dim):
 class LearnedPositionEmbeddings(nn.Module):
 def __init__(self, seq_len, model_dim, init=.02, relative=False):
 super().__init__()
-self.emb = nn.Embedding(seq_len, model_dim)
+# nn.Embedding
+self.emb = ml.Embedding(seq_len, model_dim)
 # Initializing this way is standard for GPT-2
 self.emb.weight.data.normal_(mean=0.0, std=init)
 self.relative = relative
@ -7,6 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlo
|
||||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
import torch_intermediary as ml
|
||||||
|
|
||||||
|
|
||||||
def is_latent(t):
|
def is_latent(t):
|
||||||
|
@ -19,7 +20,8 @@ def is_sequence(t):
|
||||||
class MultiGroupEmbedding(nn.Module):
|
class MultiGroupEmbedding(nn.Module):
|
||||||
def __init__(self, tokens, groups, dim):
|
def __init__(self, tokens, groups, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
|
# nn.Embedding
|
||||||
|
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||||
|
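Note: every hunk in this change leans on a `torch_intermediary` module (imported as `ml`) that is not shown in this excerpt. The sketch below is only an illustration of the kind of shim being assumed — drop-in stand-ins for `nn.Linear` / `nn.Embedding` that route to bitsandbytes' 8-bit layers when that package is present and fall back to stock `torch.nn` otherwise. The `USE_8BIT` flag and the exact routing are assumptions, not the repository's actual code.

# torch_intermediary.py -- illustrative sketch only, not the module shipped with this patch.
# Assumption: Linear / Embedding must remain class objects (not factory functions), because
# later hunks use them in `type(m) in {...}` and `isinstance(...)` checks.
import torch.nn as nn

try:
    import bitsandbytes as bnb
    _HAS_BNB = True
except ImportError:
    _HAS_BNB = False

USE_8BIT = False  # hypothetical switch; the real module may decide this per-layer or via config

if _HAS_BNB and USE_8BIT:
    Linear = bnb.nn.Linear8bitLt        # bitsandbytes' 8-bit drop-in for nn.Linear
    Embedding = bnb.nn.StableEmbedding  # bitsandbytes' embedding variant
else:
    Linear = nn.Linear
    Embedding = nn.Embedding

With a shim of this shape, every `ml.Linear(...)` / `ml.Embedding(...)` constructor call keeps the familiar `torch.nn` signature, so the call sites need nothing beyond the one-line swaps shown in these hunks.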
@@ -100,15 +102,17 @@ class TransformerDiffusionTTS(nn.Module):
             ff_glu=True,
             rotary_pos_emb=True,
         )
-        self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
-        self.type_embedding = nn.Embedding(types, model_channels)
+        self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels)
+        # nn.Embedding
+        self.type_embedding = ml.Embedding(types, model_channels)

         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
         if in_groups is None:
-            self.embeddings = nn.Embedding(token_count, model_channels)
+            # nn.Embedding
+            self.embeddings = ml.Embedding(token_count, model_channels)
         else:
             self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
         self.latent_conditioner = nn.Sequential(

@@ -140,7 +144,7 @@ class TransformerDiffusionTTS(nn.Module):
         self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)

         self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
-        self.intg = nn.Linear(model_channels*2, model_channels)
+        self.intg = ml.Linear(model_channels*2, model_channels)
         self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])

         self.out = nn.Sequential(

@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml

 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock

@@ -19,7 +20,8 @@ def is_sequence(t):
 class MultiGroupEmbedding(nn.Module):
     def __init__(self, tokens, groups, dim):
         super().__init__()
-        self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
+        # nn.Embedding
+        self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])

     def forward(self, x):
         h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]

@@ -40,7 +42,7 @@ class DietAttentionBlock(TimestepBlock):
     def __init__(self, in_dim, dim, heads, dropout):
         super().__init__()
         self.rms_scale_norm = RMSScaleShiftNorm(in_dim)
-        self.proj = nn.Linear(in_dim, dim)
+        self.proj = ml.Linear(in_dim, dim)
         self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout)
         self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
@@ -105,15 +107,17 @@ class TransformerDiffusionTTS(nn.Module):
             ff_glu=True,
             rotary_pos_emb=True,
         )
-        self.clvp_encoder = nn.Linear(clvp_in_dim, prenet_channels)
-        self.type_embedding = nn.Embedding(types, prenet_channels)
+        self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels)
+        # nn.Embedding
+        self.type_embedding = ml.Embedding(types, prenet_channels)

         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
         if in_groups is None:
-            self.embeddings = nn.Embedding(token_count, prenet_channels)
+            # nn.Embedding
+            self.embeddings = ml.Embedding(token_count, prenet_channels)
         else:
             self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels)
         self.latent_conditioner = nn.Sequential(

@@ -144,8 +148,8 @@ class TransformerDiffusionTTS(nn.Module):
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))

         self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
-        self.cond_intg = nn.Linear(prenet_channels*4, model_channels)
-        self.intg = nn.Linear(prenet_channels*2, model_channels)
+        self.cond_intg = ml.Linear(prenet_channels*4, model_channels)
+        self.intg = ml.Linear(prenet_channels*2, model_channels)

         self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])

@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
+import torch_intermediary as ml

 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \

@@ -247,14 +248,16 @@ class DiffusionTts(nn.Module):
         )

         embedding_dim = model_channels * 8
-        self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim)
+        # nn.Embedding
+        self.code_embedding = ml.Embedding(num_tokens+1, embedding_dim)
         self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
                                                     attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
         self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1)

         self.enable_unaligned_inputs = enabled_unaligned_inputs
         if enabled_unaligned_inputs:
-            self.unaligned_embedder = nn.Embedding(num_unaligned_tokens, embedding_dim)
+            # nn.Embedding
+            self.unaligned_embedder = ml.Embedding(num_unaligned_tokens, embedding_dim)
             self.unaligned_encoder = CheckpointedXTransformerEncoder(
                 max_seq_len=-1,
                 use_pos_emb=False,

@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
 from x_transformers import Encoder
+import torch_intermediary as ml

 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \

@@ -206,7 +207,8 @@ class DiffusionTts(nn.Module):
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
         self.code_converter = nn.Sequential(
-            nn.Embedding(in_tokens, conditioning_dim),
+            # nn.Embedding
+            ml.Embedding(in_tokens, conditioning_dim),
             CheckpointedXTransformerEncoder(
                 needs_permute=False,
                 max_seq_len=-1,

@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
+import torch_intermediary as ml

 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
@@ -193,7 +194,9 @@ class DiffusionTtsFlat(nn.Module):
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
         # transformer network.
-        self.code_embedding = nn.Embedding(in_tokens, model_channels)
+        # nn.Embedding
+        self.code_embedding = ml.Embedding(in_tokens, model_channels)
         self.code_converter = nn.Sequential(
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),

@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from transformers import GPT2Config, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

@@ -12,6 +13,7 @@ from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_e
 from trainer.networks import register_model
 from utils.util import opt_get

+import torch_intermediary as ml

 class ResBlock(nn.Module):
     """

@@ -279,9 +281,11 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
-        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
+        # nn.Embedding
+        self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
         if use_mel_codes_as_input:
-            self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+            # nn.Embedding
+            self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
         else:
             self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \

@@ -294,8 +298,8 @@ class UnifiedVoice(nn.Module):
            self.text_solo_embedding = 0

         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
-        self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
+        self.text_head = ml.Linear(model_dim, self.number_text_tokens)
+        self.mel_head = ml.Linear(model_dim, self.number_mel_codes)

         # Initialize the embeddings per the GPT-2 scheme
         embeddings = [self.text_embedding]

@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml

 from transformers import GPT2Config, GPT2PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
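Several of the models above re-initialize their embeddings immediately after construction ("Initialize the embeddings per the GPT-2 scheme", `self.emb.weight.data.normal_(mean=0.0, std=init)`). That pattern only survives the swap if the substituted class exposes a `.weight` parameter exactly like `nn.Embedding` does. A small illustrative check, using plain `nn.Embedding` as a stand-in for `ml.Embedding`:

import torch.nn as nn

emb = nn.Embedding(256, 1024)                 # stand-in for ml.Embedding(256, 1024)
emb.weight.data.normal_(mean=0.0, std=0.02)   # the same GPT-2-style init the patched models apply
assert emb.weight.shape == (256, 1024)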
@@ -271,15 +273,17 @@ class UnifiedVoice(nn.Module):
         self.model_dim = model_dim
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
-        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
-        self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+        # nn.Embedding
+        self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
+        # nn.Embedding
+        self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
             build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)

         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
-        self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
-        self.aligned_head = nn.Linear(model_dim, number_aligned_text_codes)
+        self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
+        self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
+        self.aligned_head = ml.Linear(model_dim, number_aligned_text_codes)

         # Initialize the embeddings per the GPT-2 scheme
         embeddings = [self.text_embedding, self.mel_embedding]

@@ -11,6 +11,7 @@ from models.audio.tts.transformer_builders import build_hf_gpt_transformer
 from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
 from trainer.networks import register_model
 from utils.util import opt_get
+import torch_intermediary as ml


 class ResBlock(nn.Module):

@@ -255,15 +256,17 @@ class UnifiedVoice(nn.Module):
         self.model_dim = model_dim
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
-        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
-        self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
+        # nn.Embedding
+        self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
+        # nn.Embedding
+        self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
         self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
             build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)

         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
-        self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
-        self.alignment_head = nn.Linear(model_dim, 256)
+        self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
+        self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
+        self.alignment_head = ml.Linear(model_dim, 256)

         if only_alignment_head:
             for p in self.parameters():

@@ -8,6 +8,7 @@ from models.audio.tts.mini_encoder import AudioMiniEncoder
 from trainer.injectors.spec_augment import spec_augment
 from trainer.networks import register_model
 from utils.util import opt_get
+import torch_intermediary as ml


 def exists(val):

@@ -36,7 +37,7 @@ class VoiceCLIP(nn.Module):
         self.encoder = AudioMiniEncoder(80, encoder_output)
         if pretrained_encoder_dict_path is not None:
             self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
-        self.to_latent = nn.Linear(encoder_output, dim_latent, bias=False)
+        self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False)
         self.temperature = nn.Parameter(torch.tensor(1.))
         self.mel_compression_ratio = mel_compression_ratio
@@ -7,6 +7,7 @@ from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper

 from models.audio.tts.mini_encoder import AudioMiniEncoder
 from trainer.networks import register_model
+import torch_intermediary as ml


 class CheckpointedLayer(nn.Module):

@@ -56,7 +57,8 @@ class Wav2VecMatcher(nn.Module):
         WAV2VEC_CHANNELS = 1024
         self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1,
                                                      attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
-        self.text_embedding = nn.Embedding(num_text_tokens, model_dim)
+        # nn.Embedding
+        self.text_embedding = ml.Embedding(num_text_tokens, model_dim)
         self.encoder = CheckpointedXTransformer(
             max_seq_len=-1,
             use_pos_emb=False,

@@ -73,8 +75,8 @@ class Wav2VecMatcher(nn.Module):
         )
         self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim))
         self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim))
-        self.w2v_query_encoder = nn.Linear(WAV2VEC_CHANNELS, model_dim)
-        self.w2v_value_encoder = nn.Linear(WAV2VEC_CHANNELS, model_dim)
+        self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
+        self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
         self.decoder = CheckpointedXTransformer(
             max_seq_len=-1, # Should be unused
             use_pos_emb=False,

@@ -10,6 +10,7 @@

 import torch
 import torch.nn as nn
+import torch_intermediary as ml

 from trainer.networks import register_model

@@ -98,7 +99,7 @@ class ResNet(nn.Module):
         self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
         self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(256 * block.expansion, num_classes)
+        self.fc = ml.Linear(256 * block.expansion, num_classes)

     def _make_layer(self, block, out_channels, num_blocks, stride):
         """make resnet layers(by layer i didnt mean this 'layer' was the

@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 from torchvision.models.resnet import BasicBlock, Bottleneck
 import torchvision
+import torch_intermediary as ml


 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -194,5 +195,5 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
 def register_resnet50(opt_net, opt):
     model = resnet50(pretrained=opt_net['pretrained'])
     if opt_net['custom_head_logits']:
-        model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
+        model.fc = ml.Linear(512 * 4, opt_net['custom_head_logits'])
     return model

@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_intermediary as ml

 from trainer.networks import register_model

@@ -101,7 +102,7 @@ class ResNet(nn.Module):
         self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
         self.conv5_x = self._make_layer(block, 256, num_block[3], 2)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(256 * block.expansion, num_classes)
+        self.fc = ml.Linear(256 * block.expansion, num_classes)

     def _make_layer(self, block, out_channels, num_blocks, stride):
         """make resnet layers(by layer i didnt mean this 'layer' was the

@@ -11,6 +11,7 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
 from models.vqvae.scaled_weight_conv import ScaledWeightConv
 from trainer.networks import register_model
 from utils.util import checkpoint
+import torch_intermediary as ml

 model_urls = {
     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',

@@ -213,7 +214,7 @@ class ResNet(nn.Module):
         self.layer4 = self._make_layer(block, 512, layers[3], breadth, stride=2,
                                        dilate=replace_stride_with_dilation[2])
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(512 * block.expansion, num_classes)
+        self.fc = ml.Linear(512 * block.expansion, num_classes)

         for m in self.modules():
             if isinstance(m, ScaledWeightConv):

@@ -3,7 +3,7 @@ import torch.nn as nn

 from trainer.networks import register_model
 from utils.util import opt_get
+import torch_intermediary as ml

 class WideKernelVgg(nn.Module):
     def __init__(self, nf=64, num_classes=2):

@@ -49,9 +49,9 @@ class WideKernelVgg(nn.Module):
             nn.ReLU(),
             nn.MaxPool2d(kernel_size=2),
             nn.Flatten(),
-            nn.Linear(nf * 8 * 4 * 2, 100),
+            ml.Linear(nf * 8 * 4 * 2, 100),
             nn.ReLU(),
-            nn.Linear(100, num_classes)
+            ml.Linear(100, num_classes)
         )

         # These normalization constants should be derived experimentally.

@@ -10,6 +10,7 @@ from models.arch_util import AttentionBlock
 from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
 from trainer.networks import register_model
 from utils.util import opt_get, checkpoint
+import torch_intermediary as ml


 def exists(val):
@@ -58,7 +59,8 @@ class CollapsingTransformer(nn.Module):
 class ConvFormatEmbedding(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
-        self.emb = nn.Embedding(*args, **kwargs)
+        # nn.Embedding
+        self.emb = ml.Embedding(*args, **kwargs)

     def forward(self, x):
         y = self.emb(x)

@@ -98,9 +100,10 @@ class CLVP(nn.Module):
         self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True)
         self.mask_conditioning_percentage = mask_conditioning_percentage

-        self.text_emb = nn.Embedding(num_text_tokens, model_dim)
+        # nn.Embedding
+        self.text_emb = ml.Embedding(num_text_tokens, model_dim)
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
-        self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+        self.to_text_latent = ml.Linear(latent_dim, latent_dim, bias=False)
         self.distributed_collect = distributed_collect

         if mel_codes is None:

@@ -108,7 +111,7 @@ class CLVP(nn.Module):
         else:
             self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
-        self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+        self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)

     def get_grad_norm_parameter_groups(self):
         return {

@@ -9,6 +9,7 @@ from models.arch_util import AttentionBlock
 from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
 from trainer.networks import register_model
 from utils.util import opt_get, checkpoint
+import torch_intermediary as ml


 def exists(val):

@@ -178,7 +179,8 @@ class CollapsingTransformer(nn.Module):
 class ConvFormatEmbedding(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
-        self.emb = nn.Embedding(*args, **kwargs)
+        # nn.Embedding
+        self.emb = ml.Embedding(*args, **kwargs)

     def forward(self, x):
         y = self.emb(x)

@@ -203,8 +205,8 @@ class ContrastiveAudio(nn.Module):
         self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
                                  nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
         self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent)
-        self.to_latent = nn.Linear(latent_dim, latent_dim, bias=False)
-        self.to_latent2 = nn.Linear(latent_dim, latent_dim, bias=False)
+        self.to_latent = ml.Linear(latent_dim, latent_dim, bias=False)
+        self.to_latent2 = ml.Linear(latent_dim, latent_dim, bias=False)

         self.to_latent2.weight.data = self.to_latent.weight.data
         self.to_latent2.weight.DO_NOT_TRAIN = True
@@ -10,6 +10,7 @@ from models.arch_util import AttentionBlock
 from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder
 from trainer.networks import register_model
 from utils.util import opt_get, checkpoint
+import torch_intermediary as ml


 def exists(val):

@@ -58,7 +59,8 @@ class CollapsingTransformer(nn.Module):
 class ConvFormatEmbedding(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
-        self.emb = nn.Embedding(*args, **kwargs)
+        # nn.Embedding
+        self.emb = ml.Embedding(*args, **kwargs)

     def forward(self, x):
         y = self.emb(x)

@@ -86,14 +88,14 @@ class CVVP(nn.Module):
         self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
                                       nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
         self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
-        self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+        self.to_conditioning_latent = ml.Linear(latent_dim, latent_dim, bias=False)

         if mel_codes is None:
             self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
         else:
             self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
-        self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
+        self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False)

     def get_grad_norm_parameter_groups(self):
         return {

@@ -7,6 +7,7 @@ from torch import einsum
 from models.lucidrains.dalle.transformer import Transformer
 from trainer.networks import register_model
 from utils.util import opt_get
+import torch_intermediary as ml


 def exists(val):
@@ -45,17 +46,20 @@ class MelTextCLIP(nn.Module):
                 mel_compression=256,
                 ):
        super().__init__()
-        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
-        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
+        # nn.Embedding
+        self.text_emb = ml.Embedding(num_text_tokens, dim_text)
+        # nn.Embedding
+        self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
         self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
                                             heads=text_heads, rotary_emb=False)
-        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
+        self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)

         self.speech_enc = nn.Conv1d(80, dim_speech, kernel_size=3, padding=1)
-        self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
+        # nn.Embedding
+        self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)
         self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
                                               depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
-        self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+        self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)

         self.temperature = nn.Parameter(torch.tensor(1.))
         self.text_mask_percentage = text_mask_percentage

@@ -7,6 +7,7 @@ from models.audio.tts.unified_voice2 import ConditioningEncoder
 from models.lucidrains.dalle.transformer import Transformer
 from trainer.networks import register_model
 from utils.util import opt_get
+import torch_intermediary as ml


 def exists(val):

@@ -45,7 +46,7 @@ class VoiceCondCLIP(nn.Module):
         self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
         self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
                                               depth=speech_enc_depth, heads=speech_heads, rotary_emb=False)
-        self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+        self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)

         self.temperature = nn.Parameter(torch.tensor(1.))
         self.voice_mask_percentage = voice_mask_percentage

@@ -11,6 +11,7 @@ from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
 from models.lucidrains.dalle.transformer import Transformer
 from trainer.networks import register_model
 from utils.util import opt_get
+import torch_intermediary as ml


 def exists(val):

@@ -53,11 +54,13 @@ class VoiceCLIP(nn.Module):
                 distributed_collect=False,
                 ):
        super().__init__()
-        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
-        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
+        # nn.Embedding
+        self.text_emb = ml.Embedding(num_text_tokens, dim_text)
+        self.to_text_latent = ml.Linear(dim_text, dim_latent, bias=False)

-        self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
-        self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
+        # nn.Embedding
+        self.speech_emb = ml.Embedding(num_speech_tokens, dim_speech)
+        self.to_speech_latent = ml.Linear(dim_speech, dim_latent, bias=False)

         if use_xformers:
             self.text_transformer = CheckpointedXTransformerEncoder(

@@ -105,8 +108,10 @@ class VoiceCLIP(nn.Module):
         self.min_mel_size = min_mel_size
         self.distributed_collect = distributed_collect
         if not use_xformers:
-            self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
-            self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
+            # nn.Embedding
+            self.text_pos_emb = ml.Embedding(text_seq_len, dim_text)
+            # nn.Embedding
+            self.speech_pos_emb = ml.Embedding(num_speech_tokens, dim_speech)

     def embed_text(self, text):
         text_mask = torch.ones_like(text.float()).bool()
@@ -6,6 +6,7 @@ import math

 import torch as th
 import torch.nn as nn
+import torch_intermediary as ml


 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.

@@ -36,7 +37,7 @@ def linear(*args, **kwargs):
     """
     Create a linear module.
     """
-    return nn.Linear(*args, **kwargs)
+    return ml.Linear(*args, **kwargs)


 def avg_pool_nd(dims, *args, **kwargs):

@@ -6,6 +6,7 @@ from models.arch_util import ConvGnLelu, default_init_weights, make_layer
 from models.diffusion.nn import timestep_embedding
 from trainer.networks import register_model
 from utils.util import checkpoint
+import torch_intermediary as ml


 # Conditionally uses torch's checkpoint functionality if it is enabled in the opt file.

@@ -28,7 +29,7 @@ class ResidualDenseBlock(nn.Module):
         self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True)
         self.emb_layers = nn.Sequential(
             nn.SiLU(),
-            nn.Linear(
+            ml.Linear(
                 mid_channels*4,
                 mid_channels,
             ),

@@ -143,9 +144,9 @@ class RRDBNet(nn.Module):
         # Guided diffusion uses a time embedding.
         time_embed_dim = mid_channels * 4
         self.time_embed = nn.Sequential(
-            nn.Linear(mid_channels, time_embed_dim),
+            ml.Linear(mid_channels, time_embed_dim),
             nn.SiLU(),
-            nn.Linear(time_embed_dim, time_embed_dim),
+            ml.Linear(time_embed_dim, time_embed_dim),
         )

         self.body = make_layer(

@@ -20,6 +20,7 @@ from models.diffusion.nn import (
 )
 from trainer.networks import register_model
 from utils.util import checkpoint
+import torch_intermediary as ml


 class AttentionPool2d(nn.Module):

@@ -515,7 +516,8 @@ class UNetModel(nn.Module):
         )

         if self.num_classes is not None:
-            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            # nn.Embedding
+            self.label_emb = ml.Embedding(num_classes, time_embed_dim)
         self.use_raw_y_as_embedding = use_raw_y_as_embedding
         assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
@@ -867,16 +869,16 @@ class EncoderUNetModel(nn.Module):
             )
         elif pool == "spatial":
             self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
+                ml.Linear(self._feature_size, 2048),
                 nn.ReLU(),
-                nn.Linear(2048, self.out_channels),
+                ml.Linear(2048, self.out_channels),
             )
         elif pool == "spatial_v2":
             self.out = nn.Sequential(
-                nn.Linear(self._feature_size, 2048),
+                ml.Linear(self._feature_size, 2048),
                 normalization(2048),
                 nn.SiLU(),
-                nn.Linear(2048, self.out_channels),
+                ml.Linear(2048, self.out_channels),
             )
         else:
             raise NotImplementedError(f"Unexpected {pool} pooling")

@@ -26,6 +26,7 @@ from models.diffusion.nn import (
 )
 from trainer.networks import register_model
 from utils.util import checkpoint
+import torch_intermediary as ml


 class AttentionPool2d(nn.Module):

@@ -476,7 +477,8 @@ class UNetModel(nn.Module):
         )

         if self.num_classes is not None:
-            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            # nn.Embedding
+            self.label_emb = ml.Embedding(num_classes, time_embed_dim)

         self.input_blocks = nn.ModuleList(
             [

@@ -736,7 +738,7 @@ class ResNetEncoder(nn.Module):
                                        dilate=replace_stride_with_dilation[2])
         f=512
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(f * block.expansion, output_dim)
+        self.fc = ml.Linear(f * block.expansion, output_dim)

         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -5,6 +5,7 @@ import torch.nn.functional as F

 from trainer.networks import register_model
 from utils.util import checkpoint, opt_get
+import torch_intermediary as ml


 class Discriminator_VGG_128(nn.Module):

@@ -46,8 +47,8 @@ class Discriminator_VGG_128(nn.Module):
         input_img_factor = input_img_factor // 2
         final_nf = nf * 16

-        self.linear1 = nn.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
-        self.linear2 = nn.Linear(100, 1)
+        self.linear1 = ml.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100)
+        self.linear2 = ml.Linear(100, 1)

         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

@@ -129,8 +130,8 @@ class Discriminator_VGG_128_GN(nn.Module):
         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

-        self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
-        self.linear2 = nn.Linear(100, 1)
+        self.linear1 = ml.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100)
+        self.linear2 = ml.Linear(100, 1)

     def compute_body(self, x):
         fea = self.lrelu(self.conv0_0(x))

@@ -219,8 +220,8 @@ class DiscriminatorVGG448GN(nn.Module):
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

         final_nf = nf * 8
-        self.linear1 = nn.Linear(int(final_nf * 7 * 7), 100)
-        self.linear2 = nn.Linear(100, 1)
+        self.linear1 = ml.Linear(int(final_nf * 7 * 7), 100)
+        self.linear2 = ml.Linear(100, 1)

         # Assign all new heads to the new param group.2
         for m in [self.convn1_0, self.convn1_1, self.bnn1_1, self.conv0_0_new, self.bn0_0]:

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.init as init
 import torch.nn.functional as F
+import torch_intermediary as ml


 def initialize_weights(net_l, scale=1):

@@ -14,7 +15,7 @@ def initialize_weights(net_l, scale=1):
                 m.weight.data *= scale # for residual block
                 if m.bias is not None:
                     m.bias.data.zero_()
-            elif isinstance(m, nn.Linear):
+            elif isinstance(m, ml.Linear):
                 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                 m.weight.data *= scale
                 if m.bias is not None:

@@ -28,6 +28,7 @@ except:
     APEX_AVAILABLE = False

 assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
+import torch_intermediary as ml

 num_cores = multiprocessing.cpu_count()
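The StyleGAN2 hunks below, like `initialize_weights` above, dispatch on the class object itself (`type(m) in {nn.Conv2d, ml.Linear}`), so those checks only match modules that were actually constructed from `ml.Linear`. A tiny illustrative check, assuming `ml.Linear` aliases or subclasses `nn.Linear` as in the earlier sketch:

import torch.nn as nn

MlLinear = nn.Linear            # stand-in for torch_intermediary's Linear on the default path
m = MlLinear(16, 8)
assert type(m) in {nn.Conv2d, MlLinear}
assert isinstance(m.weight, nn.Parameter)  # the kaiming_normal_ weight-init hooks still apply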
@@ -351,7 +352,7 @@ class RGBBlock(nn.Module):
     def __init__(self, latent_dim, input_channel, upsample, rgba=False):
         super().__init__()
         self.input_channel = input_channel
-        self.to_style = nn.Linear(latent_dim, input_channel)
+        self.to_style = ml.Linear(latent_dim, input_channel)

         out_filters = 3 if not rgba else 4
         self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

@@ -489,16 +490,16 @@ class GeneratorBlockWithStructure(nn.Module):

         # Uses stylegan1 style blocks for injecting structural latent.
         self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1)
-        self.to_noise0 = nn.Linear(1, filters)
+        self.to_noise0 = ml.Linear(1, filters)
         self.noise0 = equal_lr(NoiseInjection(filters))
         self.adain0 = AdaptiveInstanceNorm(filters, latent_dim)

-        self.to_style1 = nn.Linear(latent_dim, filters)
-        self.to_noise1 = nn.Linear(1, filters)
+        self.to_style1 = ml.Linear(latent_dim, filters)
+        self.to_noise1 = ml.Linear(1, filters)
         self.conv1 = Conv2DMod(filters, filters, 3)

-        self.to_style2 = nn.Linear(latent_dim, filters)
-        self.to_noise2 = nn.Linear(1, filters)
+        self.to_style2 = ml.Linear(latent_dim, filters)
+        self.to_noise2 = ml.Linear(1, filters)
         self.conv2 = Conv2DMod(filters, filters, 3)

         self.activation = leaky_relu()

@@ -540,12 +541,12 @@ class GeneratorBlock(nn.Module):
             self.structure_conv = nn.Conv2d(3, input_channels, 3, padding=1)
             input_channels = input_channels * 2

-        self.to_style1 = nn.Linear(latent_dim, input_channels)
-        self.to_noise1 = nn.Linear(1, filters)
+        self.to_style1 = ml.Linear(latent_dim, input_channels)
+        self.to_noise1 = ml.Linear(1, filters)
         self.conv1 = Conv2DMod(input_channels, filters, 3)

-        self.to_style2 = nn.Linear(latent_dim, filters)
-        self.to_noise2 = nn.Linear(1, filters)
+        self.to_style2 = ml.Linear(latent_dim, filters)
+        self.to_noise2 = ml.Linear(1, filters)
         self.conv2 = Conv2DMod(filters, filters, 3)

         self.activation = leaky_relu()

@@ -724,7 +725,7 @@ class StyleGan2GeneratorWithLatent(nn.Module):

     def _init_weights(self):
         for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear} and hasattr(m, 'weight'):
+            if type(m) in {nn.Conv2d, ml.Linear} and hasattr(m, 'weight'):
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

         for block in self.gen.blocks:

@@ -804,7 +805,7 @@ class StyleGan2Discriminator(nn.Module):

         self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
         self.flatten = Flatten()
-        self.to_logit = nn.Linear(latent_dim, 1)
+        self.to_logit = ml.Linear(latent_dim, 1)

         self._init_weights()

@@ -836,7 +837,7 @@ class StyleGan2Discriminator(nn.Module):

     def _init_weights(self):
         for m in self.modules():
-            if type(m) in {nn.Conv2d, nn.Linear}:
+            if type(m) in {nn.Conv2d, ml.Linear}:
                 nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
@@ -12,6 +12,7 @@ from torch import nn
 from data.images.byol_attachment import RandomApply
 from trainer.networks import register_model, create_model
 from utils.util import checkpoint, opt_get
+import torch_intermediary as ml


 def default(val, def_val):
@@ -78,10 +79,10 @@ class MLP(nn.Module):
     def __init__(self, dim, projection_size, hidden_size=4096):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(dim, hidden_size),
+            ml.Linear(dim, hidden_size),
             nn.BatchNorm1d(hidden_size),
             nn.ReLU(inplace=True),
-            nn.Linear(hidden_size, projection_size)
+            ml.Linear(hidden_size, projection_size)
         )

     def forward(self, x):
@@ -103,10 +104,10 @@ class StructuralMLP(nn.Module):
             nn.BatchNorm2d(c),
             nn.ReLU(inplace=True),
             nn.Flatten(),
-            nn.Linear(flattened_dim, hidden_size),
+            ml.Linear(flattened_dim, hidden_size),
             nn.BatchNorm1d(hidden_size),
             nn.ReLU(inplace=True),
-            nn.Linear(hidden_size, projection_size)
+            ml.Linear(hidden_size, projection_size)
         )

     def forward(self, x):
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import numpy as np
+import torch_intermediary as ml


 __all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
@@ -108,8 +109,8 @@ class FixupResNet(nn.Module):
         self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
         self.bias2 = nn.Parameter(torch.zeros(1))
         reduced_img_sz = int(input_img_size / 32)
-        self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
-        self.fc2 = nn.Linear(100, num_classes)
+        self.fc1 = ml.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
+        self.fc2 = ml.Linear(100, num_classes)

         for m in self.modules():
             if isinstance(m, FixupBasicBlock):
@@ -124,7 +125,7 @@ class FixupResNet(nn.Module):
                 if m.downsample is not None:
                     nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
             '''
-            elif isinstance(m, nn.Linear):
+            elif isinstance(m, ml.Linear):
                 nn.init.constant_(m.weight, 0)
                 nn.init.constant_(m.bias, 0)'''
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from models.arch_util import ResBlock
 from models.lucidrains.x_transformers import Encoder
 from trainer.networks import register_model
+import torch_intermediary as ml


 class VitLatent(nn.Module):
@@ -31,10 +32,10 @@ class VitLatent(nn.Module):
             do_checkpointing=True
         )

-        self.mlp = nn.Sequential(nn.Linear(hidden_dim, hidden_dim*2),
+        self.mlp = nn.Sequential(ml.Linear(hidden_dim, hidden_dim*2),
                                  nn.BatchNorm1d(hidden_dim*2),
                                  nn.ReLU(inplace=True),
-                                 nn.Linear(hidden_dim*2, hidden_dim))
+                                 ml.Linear(hidden_dim*2, hidden_dim))

     def provide_ema(self, ema):
         self.ema = ema
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat

 from rotary_embedding_torch import apply_rotary_emb
+import torch_intermediary as ml

 # helpers

@@ -47,9 +48,9 @@ class Attention(nn.Module):
         self.stable = stable
         self.causal = causal

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, dim),
+            ml.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )

@@ -102,10 +103,10 @@ class SparseConvCausalAttention(nn.Module):

         self.stable = stable

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, dim),
+            ml.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )

@@ -222,10 +223,10 @@ class SparseAxialCausalAttention(nn.Module):

         self.stable = stable

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, dim),
+            ml.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )
@@ -11,6 +11,7 @@ from models.lucidrains.dalle.attention import Attention, SparseAttention, Sparse

 from rotary_embedding_torch import RotaryEmbedding, broadcat
 from g_mlp_pytorch import gMLPBlock
+import torch_intermediary as ml

 # helpers

@@ -78,10 +79,10 @@ class FeedForward(nn.Module):
     def __init__(self, dim, dropout = 0., mult = 4.):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(dim, dim * mult * 2),
+            ml.Linear(dim, dim * mult * 2),
             GEGLU(),
             nn.Dropout(dropout),
-            nn.Linear(dim * mult, dim)
+            ml.Linear(dim * mult, dim)
         )

     def forward(self, x):
@@ -21,6 +21,7 @@ try:
     APEX_AVAILABLE = True
 except:
     APEX_AVAILABLE = False
+import torch_intermediary as ml

 # helpers

@@ -356,10 +357,10 @@ class FeedForward(nn.Module):
         activation = default(activation, nn.GELU)

         self.glu = glu
-        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
+        self.w1 = ml.Linear(dim, dim * mult * (2 if glu else 1))
         self.act = activation()
         self.dropout = nn.Dropout(dropout)
-        self.w2 = nn.Linear(dim * mult, dim)
+        self.w2 = ml.Linear(dim * mult, dim)

     def forward(self, x, **kwargs):
         if not self.glu:
@@ -401,10 +402,10 @@ class Attention(nn.Module):
         self.global_heads = heads - local_heads
         self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None

-        self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
-        self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
-        self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
-        self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias)
+        self.to_q = ml.Linear(dim, inner_dim, bias = qkv_bias)
+        self.to_k = ml.Linear(dim, inner_dim, bias = qkv_bias)
+        self.to_v = ml.Linear(dim, inner_dim, bias = qkv_bias)
+        self.to_out = ml.Linear(inner_dim, dim, bias = attn_out_bias)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs):
@@ -458,7 +459,8 @@ class CrossAttention(Attention):
 class AbsolutePositionalEmbedding(nn.Module):
     def __init__(self, dim, max_seq_len):
         super().__init__()
-        self.emb = nn.Embedding(max_seq_len, dim)
+        # nn.Embedding
+        self.emb = ml.Embedding(max_seq_len, dim)

     def forward(self, x):
         t = torch.arange(x.shape[1], device=x.device)
@@ -619,7 +621,8 @@ class PerformerLM(nn.Module):
         local_attn_heads = cast_tuple(local_attn_heads)

         self.max_seq_len = max_seq_len
-        self.token_emb = nn.Embedding(num_tokens, dim)
+        # nn.Embedding
+        self.token_emb = ml.Embedding(num_tokens, dim)

         if rotary_position_emb:
             self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
@@ -636,7 +639,7 @@ class PerformerLM(nn.Module):

         self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens)
         self.norm = nn.LayerNorm(dim)
-        self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None
+        self.to_out = ml.Linear(dim, num_tokens) if not tie_embed else None

     def check_redraw_projections(self):
         self.performer.check_redraw_projections()
@@ -8,6 +8,7 @@ from torch.cuda.amp import autocast

 from einops import rearrange, repeat
 from contextlib import contextmanager
+import torch_intermediary as ml


 def par(t, nm):
@@ -355,9 +356,9 @@ class VectorQuantize(nn.Module):

         codebook_dim = default(codebook_dim, dim)
         requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \
+        self.project_in = ml.Linear(dim, codebook_dim) if requires_projection \
             else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \
+        self.project_out = ml.Linear(codebook_dim, dim) if requires_projection \
             else nn.Identity()

         self.eps = eps
@@ -11,6 +11,7 @@ from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange

 from torch.utils.checkpoint import checkpoint
+import torch_intermediary as ml

 DEFAULT_DIM_HEAD = 64

@@ -125,7 +126,8 @@ class AbsolutePositionalEmbedding(nn.Module):
     def __init__(self, dim, max_seq_len):
         super().__init__()
         self.scale = dim ** -0.5
-        self.emb = nn.Embedding(max_seq_len, dim)
+        # nn.Embedding
+        self.emb = ml.Embedding(max_seq_len, dim)

     def forward(self, x):
         n = torch.arange(x.shape[1], device=x.device)
@@ -154,7 +156,8 @@ class RelativePositionBias(nn.Module):
         self.causal = causal
         self.num_buckets = num_buckets
         self.max_distance = max_distance
-        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+        # nn.Embedding
+        self.relative_attention_bias = ml.Embedding(num_buckets, heads)

     @staticmethod
     def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
@@ -360,7 +363,7 @@ class RMSScaleShiftNorm(nn.Module):
             self.cdim = 1
             self.pdim = -1
         else:
-            self.scale_shift_process = nn.Linear(embed_dim, dim * 2, bias=bias)
+            self.scale_shift_process = ml.Linear(embed_dim, dim * 2, bias=bias)
             self.cdim = -1
             self.pdim = 1

@@ -447,7 +450,7 @@ class GLU(nn.Module):
     def __init__(self, dim_in, dim_out, activation):
         super().__init__()
         self.act = activation
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        self.proj = ml.Linear(dim_in, dim_out * 2)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -472,7 +475,7 @@ class FeedForward(nn.Module):
         activation = ReluSquared() if relu_squared else nn.GELU()

         project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            ml.Linear(dim, inner_dim),
             activation
         ) if not glu else GLU(dim, inner_dim, activation)

@@ -480,7 +483,7 @@ class FeedForward(nn.Module):
             project_in,
             nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            ml.Linear(inner_dim, dim_out)
         )

         # init last linear layer to 0
@@ -535,16 +538,16 @@ class Attention(nn.Module):
             qk_dim = int(collab_compression * qk_dim)
             self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))

-        self.to_q = nn.Linear(dim, qk_dim, bias=False)
-        self.to_k = nn.Linear(dim, qk_dim, bias=False)
-        self.to_v = nn.Linear(dim, v_dim, bias=False)
+        self.to_q = ml.Linear(dim, qk_dim, bias=False)
+        self.to_k = ml.Linear(dim, qk_dim, bias=False)
+        self.to_v = ml.Linear(dim, v_dim, bias=False)

         self.dropout = nn.Dropout(dropout)

         # add GLU gating for aggregated values, from alphafold2
         self.to_v_gate = None
         if gate_values:
-            self.to_v_gate = nn.Linear(dim, v_dim)
+            self.to_v_gate = ml.Linear(dim, v_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
             nn.init.constant_(self.to_v_gate.bias, 1)

@@ -581,7 +584,7 @@ class Attention(nn.Module):
         # attention on attention
         self.attn_on_attn = on_attn
         out_dim = default(out_dim, dim)
-        self.to_out = nn.Sequential(nn.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, out_dim)
+        self.to_out = nn.Sequential(ml.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, out_dim)

         self.rel_pos_bias = rel_pos_bias
         if rel_pos_bias:
@@ -1077,7 +1080,7 @@ class ViTransformerWrapper(nn.Module):
         self.patch_size = patch_size

         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, dim)
+        self.patch_to_embedding = ml.Linear(patch_dim, dim)
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)

@@ -1135,18 +1138,19 @@ class TransformerWrapper(nn.Module):
         self.max_mem_len = max_mem_len
         self.shift_mem_down = shift_mem_down

-        self.token_emb = nn.Embedding(num_tokens, emb_dim)
+        # nn.Embedding
+        self.token_emb = ml.Embedding(num_tokens, emb_dim)
         self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                 use_pos_emb and not attn_layers.has_pos_emb) else always(0)
         self.emb_dropout = nn.Dropout(emb_dropout)

-        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+        self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)

         self.init_()

-        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+        self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

         # memory tokens (like [cls]) from Memory Transformers paper
         num_memory_tokens = default(num_memory_tokens, 0)
@@ -1233,12 +1237,12 @@ class ContinuousTransformerWrapper(nn.Module):
                 use_pos_emb and not attn_layers.has_pos_emb) else always(0)
         self.emb_dropout = nn.Dropout(emb_dropout)

-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()

         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)

-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+        self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()

     def forward(
         self,
@@ -4,13 +4,15 @@ import torch.nn.functional as F
 from torch import einsum

 from utils.weight_scheduler import LinearDecayWeightScheduler
+import torch_intermediary as ml


 class GumbelQuantizer(nn.Module):
     def __init__(self, inp_dim, codebook_dim, num_tokens, straight_through=False):
         super().__init__()
         self.to_logits = nn.Conv1d(inp_dim, num_tokens, 1)
-        self.codebook = nn.Embedding(num_tokens, codebook_dim)
+        # nn.Embedding
+        self.codebook = ml.Embedding(num_tokens, codebook_dim)
         self.straight_through = straight_through
         self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000)
         self.step = 0
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat

 from models.arch_util import l2norm, sample_vectors, default, ema_inplace
+import torch_intermediary as ml


 def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
@@ -184,8 +185,8 @@ class VectorQuantize(nn.Module):

         codebook_dim = default(codebook_dim, dim)
         requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+        self.project_in = ml.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
+        self.project_out = ml.Linear(codebook_dim, dim) if requires_projection else nn.Identity()

         self.eps = eps
41
codes/torch_intermediary/__init__.py
Normal file

@@ -0,0 +1,41 @@
"""
from bitsandbytes.nn import Linear8bitLt as Linear
from bitsandbytes.nn import StableEmbedding as Embedding
from bitsandbytes.optim.adam import Adam8bit as Adam
from bitsandbytes.optim.adamw import AdamW8bit as AdamW
"""
"""
from torch.nn import Linear
from torch.nn import Embedding
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
"""

OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = False
OVERRIDE_ADAM = True
OVERRIDE_ADAMW = True
USE_STABLE_EMBEDDING = True

if OVERRIDE_LINEAR:
    from bitsandbytes.nn import Linear8bitLt as Linear
else:
    from torch.nn import Linear

if OVERRIDE_EMBEDDING:
    if USE_STABLE_EMBEDDING:
        from bitsandbytes.nn import StableEmbedding as Embedding
    else:
        from bitsandbytes.nn import Embedding as Embedding
else:
    from torch.nn import Embedding

if OVERRIDE_ADAM:
    from bitsandbytes.optim.adam import Adam8bit as Adam
else:
    from torch.optim.adam import Adam

if OVERRIDE_ADAMW:
    from bitsandbytes.optim.adamw import AdamW8bit as AdamW
else:
    from torch.optim.adamw import AdamW
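The rest of the patch reads directly off this shim: call sites import torch_intermediary as ml and construct ml.Linear / ml.Embedding / ml.Adam / ml.AdamW in place of the stock torch classes, so the OVERRIDE_* flags above decide in one place whether the bitsandbytes 8-bit implementations are used. A minimal usage sketch follows; the toy model and hyperparameters are invented for illustration, and only the ml.* names come from the shim:

    import torch.nn as nn
    import torch_intermediary as ml

    # With the defaults above, Linear/Embedding resolve to the plain torch.nn classes,
    # while Adam/AdamW resolve to the bitsandbytes 8-bit optimizers.
    model = nn.Sequential(ml.Embedding(1000, 256), ml.Linear(256, 10))
    opt = ml.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2, betas=(0.9, 0.999))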
@@ -21,6 +21,7 @@ import torchvision.utils as utils

 from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
 from utils.util import opt_get, denormalize
+import torch_intermediary as ml

 logger = logging.getLogger('base')

@@ -337,7 +338,7 @@ class ExtensibleTrainer(BaseModel):
            for net in self.networks.values():
                for mod in net.modules():
                    fan_in = -1
-                   if isinstance(mod, nn.Linear):
+                   if isinstance(mod, ml.Linear):
                        fan_in = mod.weight.data.shape[1]
                    elif isinstance(mod, nn.Conv1d):
                        fan_in = mod.weight.data.shape[0]
@@ -1,3 +1,4 @@
+
 import logging
 from collections import OrderedDict

@@ -6,6 +7,7 @@ import torch.nn as nn
 import trainer.networks as networks
 import trainer.lr_scheduler as lr_scheduler
 from .base_model import BaseModel
+import torch_intermediary as ml

 logger = logging.getLogger('base')

@@ -40,7 +42,8 @@ class FeatureModel(BaseModel):
                else:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
-            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+            # torch.optim.Adam
+            self.optimizer_G = ml.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
@@ -3,10 +3,10 @@ from collections import Counter
 from collections import defaultdict
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
+import torch_intermediary as ml

 from utils.util import opt_get


 def get_scheduler_for_name(name, optimizers, scheduler_opt):
     schedulers = []
     for o in optimizers:
@@ -136,7 +136,8 @@ class CosineAnnealingLR_Restart(_LRScheduler):


 if __name__ == "__main__":
-    optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
+    #torch.optim.Adam
+    optimizer = ml.Adam([torch.zeros(3, 64, 3, 3)], lr=1e-4, weight_decay=0,
                                  betas=(0.9, 0.99))
     ##############################
     # MultiStepLR_Restart
@@ -12,6 +12,7 @@ from utils.util import recursively_detach, opt_get, clip_grad_norm

 logger = logging.getLogger('base')

+import torch_intermediary as ml

 # Defines the expected API for a single training step
 class ConfigurableStep(Module):
@@ -82,7 +83,8 @@ class ConfigurableStep(Module):
        import torch.nn as nn
        norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
                        nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
-       emb_modules = (nn.Embedding, nn.EmbeddingBag)
+       # nn.Embedding
+       emb_modules = (ml.Embedding, nn.EmbeddingBag)
        param_names_notweights = set()
        all_param_names = set()
        param_map = {}
@@ -123,7 +125,8 @@ class ConfigurableStep(Module):
                { 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
                { 'params': params_notweights, 'weight_decay': 0 }
            ]
-           opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
+           # torch.optim.AdamW
+           opt = ml.AdamW(groups, lr=opt_config['lr'],
                                    weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                    betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            opt._group_names = [params_names_weights, params_names_notweights]
@@ -141,14 +144,16 @@ class ConfigurableStep(Module):
            # The torch ZeRO implementation does not seem to support parameter groups, so do not shard the non-weighted
            # parameters and just use a normal AdamW implementation. In a large network, these weights will normally
            # be a tiny fraction of the total weights.
-           opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
+           # torch.optim.AdamW
+           opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
                                     betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            opt_unweighted._config = opt_config
            opt_unweighted._config['network'] = net_name
            opt_unweighted._group_names = []
            self.optimizers.append(opt_unweighted)

-           opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=torch.optim.AdamW, lr=opt_config['lr'],
+           # torch.optim.AdamW
+           opt = ZeroRedundancyOptimizer(params_weights, optimizer_class=ml.AdamW, lr=opt_config['lr'],
                                          weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
                                          betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            opt.param_groups[0]['initial_lr'] = opt_config['lr']
@@ -162,7 +167,8 @@ class ConfigurableStep(Module):
            opt._group_names = sorted(list(all_param_names))
        elif self.step_opt['optimizer'] == 'lamb':
            from trainer.optimizers.lamb import Lamb
-           opt_unweighted = torch.optim.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
+           # torch.optim.AdamW
+           opt_unweighted = ml.AdamW(params_notweights, lr=opt_config['lr'], weight_decay=0,
                                     betas=(opt_get(opt_config, ['beta1'], .9), opt_get(opt_config, ['beta2'], .999)))
            opt_unweighted._config = opt_config
            opt_unweighted._config['network'] = net_name
@@ -45,4 +45,7 @@ rotary-embedding-torch
 axial_positional_embedding
 g-mlp-pytorch
 x-clip
 x_transformers==1.0.4
+
+# bitsandbytes
+bitsandbytes==0.35.0