Rework configs to remove redundant code
parent 7bfdad13f8
commit a2063b7000
torchscale/architecture/config.py
@@ -2,57 +2,72 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 
-class EncoderConfig:
+class Config:
     def __init__(self, **kwargs):
+        self.activation_fn = kwargs.pop("activation_fn", "gelu")
+        self.dropout = kwargs.pop("dropout", 0.0)
+        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
+        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
+        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
+        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
+        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
+        self.moe_freq = kwargs.pop("moe_freq", 0)
+        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
+        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
+        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
+        self.moe_eval_capacity_token_fraction = kwargs.pop(
+            "moe_eval_capacity_token_fraction", 0.25
+        )
+        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
+        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
+            "moe_normalize_gate_prob_before_dropping", False
+        )
+        self.use_xmoe = kwargs.pop("use_xmoe", False)
+        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
+        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
+        self.deepnorm = kwargs.pop("deepnorm", False)
+        self.subln = kwargs.pop("subln", True)
+        self.bert_init = kwargs.pop("bert_init", False)
+        self.multiway = kwargs.pop("multiway", False)
+        self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
+        self.vocab_size = kwargs.pop("vocab_size", -1)
+        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
+        self.fsdp = kwargs.pop("fsdp", False)
+        self.ddp_rank = kwargs.pop("ddp_rank", 0)
+        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
+        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
+
+        if self.use_xmoe:
+            self.moe_normalize_gate_prob_before_dropping = True
+            self.moe_second_expert_policy = "random"
+            assert self.moe_freq > 0 and self.moe_expert_count > 0
+
+    def override(self, args):
+        for hp in self.__dict__.keys():
+            if getattr(args, hp, None) is not None:
+                self.__dict__[hp] = getattr(args, hp, None)
+
+
+class EncoderConfig(Config):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
         self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
         self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
         self.encoder_layers = kwargs.pop("encoder_layers", 12)
         self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
         self.normalize_output = kwargs.pop("normalize_output", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
         self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.bert_init = kwargs.pop("bert_init", False)
-        self.multiway = kwargs.pop("multiway", False)
         self.share_encoder_input_output_embed = kwargs.pop(
             "share_encoder_input_output_embed", False
         )
         self.max_source_positions = kwargs.pop("max_source_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
         # Vision
         self.img_size = kwargs.pop("img_size", 224)
         self.patch_size = kwargs.pop("patch_size", 16)
         self.in_chans = kwargs.pop("in_chans", 3)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
 
         if self.deepnorm:
             self.encoder_normalize_before = False
@@ -60,63 +75,22 @@ class EncoderConfig:
         if self.subln:
             self.encoder_normalize_before = True
             self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
 
 
-class DecoderConfig:
+class DecoderConfig(Config):
     def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
         self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
         self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
         self.decoder_layers = kwargs.pop("decoder_layers", 12)
         self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
         self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.bert_init = kwargs.pop("bert_init", False)
-        self.multiway = kwargs.pop("multiway", False)
         self.share_decoder_input_output_embed = kwargs.pop(
             "share_decoder_input_output_embed", False
         )
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
 
         if self.deepnorm:
             self.decoder_normalize_before = False
@@ -124,85 +98,10 @@ class DecoderConfig:
         if self.subln:
             self.decoder_normalize_before = True
             self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
 
 
-class EncoderDecoderConfig:
+class EncoderDecoderConfig(EncoderConfig, DecoderConfig):
     def __init__(self, **kwargs):
-        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
-        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
-        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
-        self.encoder_layers = kwargs.pop("encoder_layers", 12)
-        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
-        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
-        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
-        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
-        self.decoder_layers = kwargs.pop("decoder_layers", 12)
-        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
-        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.bert_init = kwargs.pop("bert_init", False)
-        self.multiway = kwargs.pop("multiway", False)
+        super().__init__(**kwargs)
         self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
-        self.share_decoder_input_output_embed = kwargs.pop(
-            "share_decoder_input_output_embed", False
-        )
-        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
-        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
-
-        if self.deepnorm:
-            self.encoder_normalize_before = False
-            self.decoder_normalize_before = False
-            self.subln = False
-        if self.subln:
-            self.encoder_normalize_before = True
-            self.decoder_normalize_before = True
-            self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
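With the shared hyperparameters and override() moved onto the new Config base class, the subclasses now only declare their own fields. A minimal usage sketch against the refactored config module above (the argparse Namespace is just a stand-in for any args object; all values are the defaults visible in the diff):

from argparse import Namespace

from torchscale.architecture.config import EncoderConfig, EncoderDecoderConfig

# Shared defaults now come from Config; encoder-specific ones from EncoderConfig.
cfg = EncoderConfig(vocab_size=64000, multiway=True)
assert cfg.dropout == 0.0            # inherited Config default
assert cfg.encoder_embed_dim == 768  # EncoderConfig default

# override(), defined once on Config, copies every non-None attribute of args.
cfg.override(Namespace(dropout=0.1, encoder_layers=24))
assert cfg.dropout == 0.1 and cfg.encoder_layers == 24

# EncoderDecoderConfig now collects encoder_*, decoder_* and shared fields via the MRO:
# EncoderDecoderConfig -> EncoderConfig -> DecoderConfig -> Config.
ed_cfg = EncoderDecoderConfig(share_all_embeddings=True)
assert ed_cfg.encoder_embed_dim == ed_cfg.decoder_embed_dim == 768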
torchscale/architecture/decoder.py
@@ -9,12 +9,14 @@ import torch.nn as nn
 from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
+from torchscale.architecture.config import DecoderConfig, EncoderConfig
 from torchscale.component.droppath import DropPath
 from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
 from torchscale.component.multihead_attention import MultiheadAttention
 from torchscale.component.relative_position_bias import RelativePositionBias
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
 
 try:
     from apex.normalization import FusedLayerNorm as LayerNorm
 except ModuleNotFoundError:
@@ -23,7 +25,7 @@ except ModuleNotFoundError:
 class DecoderLayer(nn.Module):
     def __init__(
         self,
-        args,
+        args: DecoderConfig,
         depth,
         is_moe_layer=False,
         is_encoder_decoder=False,
@@ -209,7 +211,7 @@ class DecoderLayer(nn.Module):
 class Decoder(nn.Module):
     def __init__(
         self,
-        args,
+        args: DecoderConfig,
         embed_tokens=None,
         embed_positions=None,
         output_projection=None,
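The args: DecoderConfig annotation documents which config class Decoder and DecoderLayer expect; Python does not enforce the hint at runtime, but readers and static checkers benefit. A minimal construction sketch, assuming a torchscale install that includes this change:

from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder

# args is now annotated as DecoderConfig; the remaining constructor
# arguments keep the defaults shown in the hunk above.
config = DecoderConfig(vocab_size=64000)
decoder = Decoder(config)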
torchscale/architecture/encoder.py
@@ -13,6 +13,7 @@ except ModuleNotFoundError:
     from torch.nn import LayerNorm
 
 from torchscale.architecture.utils import init_bert_params
+from torchscale.architecture.config import EncoderConfig
 from torchscale.component.droppath import DropPath
 from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
 from torchscale.component.multihead_attention import MultiheadAttention
@@ -23,7 +24,11 @@ from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
 
 
 class EncoderLayer(nn.Module):
-    def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
+    def __init__(self,
+                 args: EncoderConfig,
+                 depth,
+                 is_moe_layer: bool = False,
+                 is_encoder_decoder: bool = False):
         super().__init__()
         self.args = args
         self.embed_dim = args.encoder_embed_dim
@@ -165,11 +170,11 @@ class EncoderLayer(nn.Module):
 class Encoder(nn.Module):
     def __init__(
         self,
-        args,
+        args: EncoderConfig,
         embed_tokens=None,
         embed_positions=None,
         output_projection=None,
-        is_encoder_decoder=False,
+        is_encoder_decoder: bool = False,
         **kwargs
     ):
         self.args = args
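Likewise for the encoder: args is now typed as EncoderConfig, so a static checker such as mypy can flag a mismatched config being passed in. A minimal sketch under the same assumption as above:

from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder

config = EncoderConfig(vocab_size=64000)
encoder = Encoder(config)  # passing a DecoderConfig here would now be flagged by type checkers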
torchscale/architecture/encoder_decoder.py
@@ -3,6 +3,7 @@
 
 import torch.nn as nn
 
+from torchscale.architecture.config import EncoderDecoderConfig
 from torchscale.architecture.decoder import Decoder
 from torchscale.architecture.encoder import Encoder
 
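The new import also wires EncoderDecoderConfig into the encoder-decoder wrapper. Assuming this module's wrapper class is named EncoderDecoder and follows the same construction pattern as Encoder and Decoder (the class itself is not shown in this hunk), usage would look like:

from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder  # class name assumed, not shown in the diff

config = EncoderDecoderConfig(vocab_size=64000)
model = EncoderDecoder(config)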