vall-e/vall_e/models/arch/__init__.py

65 lines
1.8 KiB
Python
Raw Normal View History

2024-06-06 01:30:43 +00:00
# Names of architectures whose backing module imported successfully; each
# optional import below appends its name(s) here on success.
AVAILABLE_ARCHES = list()

# Architecture name -> the exception raised while importing its module, so
# the reason an architecture is unavailable is kept rather than discarded.
ERROR_ARCHES = dict()
# Optional backend: "transformer".
# Any failure is recorded in ERROR_ARCHES instead of being raised, so this
# package still imports when the backend (or one of its dependencies) is
# unavailable.
try:
    from .transformer import SinusoidalEmbedding, Block as TransformerBlock
except Exception as e:
    ERROR_ARCHES["transformer"] = e
else:
    # Only advertise the architecture once its import has fully succeeded.
    AVAILABLE_ARCHES.append("transformer")
# Optional backend: "retnet".
# On failure, the exception is stored in ERROR_ARCHES rather than raised.
try:
    from .retnet import RetNetDecoder, RetNetConfig
except Exception as e:
    ERROR_ARCHES["retnet"] = e
else:
    AVAILABLE_ARCHES.append("retnet")
# Optional backend: "retnet-ts".
# The _TS aliases keep these classes from overwriting the plain
# RetNetDecoder / RetNetConfig names bound by the "retnet" import.
try:
    from .retnet_syncdoth.retnet_ts import (
        RetNetDecoder as RetNetDecoder_TS,
        RetNetConfig as RetNetConfig_TS,
    )
except Exception as e:
    ERROR_ARCHES["retnet-ts"] = e
else:
    AVAILABLE_ARCHES.append("retnet-ts")
# Optional backend: "retnet-hf".
# The _HF aliases keep these classes distinct from the plain "retnet" and the
# "_TS" variants imported above.
try:
    from .retnet_syncdoth.retnet_hf import (
        RetNetDecoder as RetNetDecoder_HF,
        RetNetConfig as RetNetConfig_HF,
        RetNetForCausalLM,
    )
except Exception as e:
    ERROR_ARCHES["retnet-hf"] = e
else:
    AVAILABLE_ARCHES.append("retnet-hf")
# Optional backend: "llama".
try:
    from .llama import (
        LlamaModel,
        LlamaModel_Adapted,
        LlamaConfig,
        AVAILABLE_ATTENTIONS,
        LlamaAttention,
        LlamaAttention_Adapted,
        LlamaDecoderLayer,
        LlamaDecoderLayer_Adapted,
        LlamaForCausalLM,
    )
except Exception as e:
    ERROR_ARCHES["llama"] = e
    # Unlike the other backends, this name is given a fallback value so it is
    # always defined at package level even when the llama module fails to
    # import (presumably because other code reads it unconditionally — confirm).
    AVAILABLE_ATTENTIONS = []
else:
    AVAILABLE_ARCHES.append("llama")
# Optional backend: "bitnet".
# On failure, the exception is stored in ERROR_ARCHES rather than raised.
try:
    from .bitnet import BitNetTransformer
except Exception as e:
    ERROR_ARCHES["bitnet"] = e
else:
    AVAILABLE_ARCHES.append("bitnet")
# Optional backend: "mixtral". Also exposes load_balancing_loss_func from
# the same module.
try:
    from .mixtral import (
        MixtralModel,
        MixtralConfig,
        MixtralAttention,
        MixtralAttention_Adapted,
        load_balancing_loss_func,
    )
except Exception as err:
    ERROR_ARCHES["mixtral"] = err
else:
    AVAILABLE_ARCHES.append("mixtral")
# Optional backends: "mamba" and "mamba2".
# Both names are backed by the single .mamba module, so they are enabled
# together or, on import failure, disabled together with the same exception.
try:
    from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
except Exception as err:
    ERROR_ARCHES.update({"mamba": err, "mamba2": err})
else:
    AVAILABLE_ARCHES.extend(["mamba", "mamba2"])
# Optional backend: "mamba2-hf", provided by the .mamba_vasqu module.
try:
    from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF
except Exception as err:
    ERROR_ARCHES["mamba2-hf"] = err
else:
    AVAILABLE_ARCHES.append("mamba2-hf")