flake8 lint checks
This commit is contained in:
parent
4714557e89
commit
994e4665a2
|
@ -1,6 +1,7 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,21 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import math
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from fairseq.distributed import fsdp_wrap
|
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
|
||||||
from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
|
|
||||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
from fairseq.models.transformer import (
|
from fairseq.models.transformer import (
|
||||||
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
|
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
|
||||||
)
|
)
|
||||||
from fairseq.modules import PositionalEmbedding
|
from fairseq.modules import PositionalEmbedding
|
||||||
from fairseq.models.squad import SQuADHead
|
from fairseq.models.squad import SQuADHead
|
||||||
from torch import Tensor
|
|
||||||
from omegaconf import II
|
from omegaconf import II
|
||||||
from .machine_translation import MTEncoder as Encoder
|
from .machine_translation import MTEncoder as Encoder
|
||||||
from torchscale.architecture.config import EncoderConfig
|
from torchscale.architecture.config import EncoderConfig
|
||||||
|
@ -28,6 +25,7 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BertConfig(FairseqDataclass):
|
class BertConfig(FairseqDataclass):
|
||||||
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||||
|
@ -177,7 +175,10 @@ class BertConfig(FairseqDataclass):
|
||||||
moe_eval_capacity_token_fraction: Optional[float] = field(
|
moe_eval_capacity_token_fraction: Optional[float] = field(
|
||||||
default=0.25,
|
default=0.25,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
|
"help": (
|
||||||
|
"Default: 0.25, Fraction of tokens as capacity during validation, "
|
||||||
|
"if set to negative, use same as training. range: (0.0, 1.0]."
|
||||||
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
moe_normalize_expert_grad: Optional[str] = field(
|
moe_normalize_expert_grad: Optional[str] = field(
|
||||||
|
@ -190,7 +191,8 @@ class BertConfig(FairseqDataclass):
|
||||||
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
||||||
)
|
)
|
||||||
dummy_a2a: Optional[bool] = field(
|
dummy_a2a: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
default=False, metadata={
|
||||||
|
"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
||||||
)
|
)
|
||||||
moe_batch_prioritized_routing: Optional[bool] = field(
|
moe_batch_prioritized_routing: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
||||||
|
@ -350,7 +352,8 @@ class BertModel(BaseFairseqModel):
|
||||||
return_all_hiddens=False,
|
return_all_hiddens=False,
|
||||||
classification_head_name=None,
|
classification_head_name=None,
|
||||||
masked_tokens=None,
|
masked_tokens=None,
|
||||||
**kwargs):
|
**kwargs
|
||||||
|
):
|
||||||
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
|
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
|
||||||
x, extra = encoder_out["encoder_out"], encoder_out
|
x, extra = encoder_out["encoder_out"], encoder_out
|
||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
|
@ -389,6 +392,7 @@ class ClassificationHead(nn.Module):
|
||||||
x = self.out_proj(x)
|
x = self.out_proj(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LMHead(nn.Module):
|
class LMHead(nn.Module):
|
||||||
"""Head for masked language modeling."""
|
"""Head for masked language modeling."""
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,12 @@
|
||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import math
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from fairseq import options, utils
|
from fairseq import utils
|
||||||
from fairseq import distributed_utils
|
from fairseq import distributed_utils
|
||||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
from fairseq.models import (
|
from fairseq.models import (
|
||||||
|
@ -29,9 +29,9 @@ from torchscale.architecture.config import DecoderConfig
|
||||||
from omegaconf import II
|
from omegaconf import II
|
||||||
|
|
||||||
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LanguageConfig(FairseqDataclass):
|
class LanguageConfig(FairseqDataclass):
|
||||||
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||||
|
@ -151,7 +151,10 @@ class LanguageConfig(FairseqDataclass):
|
||||||
moe_eval_capacity_token_fraction: Optional[float] = field(
|
moe_eval_capacity_token_fraction: Optional[float] = field(
|
||||||
default=0.25,
|
default=0.25,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
|
"help": (
|
||||||
|
"Default: 0.25, Fraction of tokens as capacity during validation, "
|
||||||
|
"if set to negative, use same as training. range: (0.0, 1.0]."
|
||||||
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
moe_normalize_expert_grad: Optional[str] = field(
|
moe_normalize_expert_grad: Optional[str] = field(
|
||||||
|
@ -164,7 +167,8 @@ class LanguageConfig(FairseqDataclass):
|
||||||
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
||||||
)
|
)
|
||||||
dummy_a2a: Optional[bool] = field(
|
dummy_a2a: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
default=False, metadata={
|
||||||
|
"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
||||||
)
|
)
|
||||||
moe_batch_prioritized_routing: Optional[bool] = field(
|
moe_batch_prioritized_routing: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
||||||
|
@ -238,10 +242,10 @@ class LanguageModel(FairseqLanguageModel):
|
||||||
output_projection.weight = embed_tokens.weight
|
output_projection.weight = embed_tokens.weight
|
||||||
else:
|
else:
|
||||||
output_projection = torch.nn.Linear(
|
output_projection = torch.nn.Linear(
|
||||||
decoder_embed_dim, len(task.dictionary), bias=False
|
args.decoder_embed_dim, len(task.dictionary), bias=False
|
||||||
)
|
)
|
||||||
torch.nn.init.normal_(
|
torch.nn.init.normal_(
|
||||||
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
|
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -252,7 +256,8 @@ class LanguageModel(FairseqLanguageModel):
|
||||||
and getattr(args, 'ddp_backend', None) != "fully_sharded"
|
and getattr(args, 'ddp_backend', None) != "fully_sharded"
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
|
assert args.fp16_no_flatten_grads, \
|
||||||
|
"If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
|
||||||
|
|
||||||
args.ddp_rank = distributed_utils.get_data_parallel_rank()
|
args.ddp_rank = distributed_utils.get_data_parallel_rank()
|
||||||
|
|
||||||
|
@ -294,6 +299,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
|
||||||
result = incremental_state[module][key].index_select(0, new_order)
|
result = incremental_state[module][key].index_select(0, new_order)
|
||||||
incremental_state[module][key] = result
|
incremental_state[module][key] = result
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture("lm", "lm_base")
|
@register_model_architecture("lm", "lm_base")
|
||||||
def base_lm_architecture(args):
|
def base_lm_architecture(args):
|
||||||
# backward compatibility for older model checkpoints
|
# backward compatibility for older model checkpoints
|
||||||
|
@ -357,4 +363,3 @@ def base_lm_architecture(args):
|
||||||
args.offload_activations = getattr(args, "offload_activations", False)
|
args.offload_activations = getattr(args, "offload_activations", False)
|
||||||
if args.offload_activations:
|
if args.offload_activations:
|
||||||
args.checkpoint_activations = True
|
args.checkpoint_activations = True
|
||||||
|
|
||||||
|
|
|
@ -6,33 +6,20 @@
|
||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import functools
|
from typing import Dict, List, Optional, Tuple
|
||||||
import math
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from fairseq.distributed import utils as dist_utils, fsdp_wrap
|
from fairseq.distributed import utils as fsdp_wrap
|
||||||
from fairseq import distributed_utils
|
from fairseq import distributed_utils
|
||||||
from fairseq import checkpoint_utils
|
|
||||||
from fairseq.models import (
|
from fairseq.models import (
|
||||||
FairseqEncoder,
|
FairseqEncoder,
|
||||||
FairseqEncoderDecoderModel,
|
FairseqEncoderDecoderModel,
|
||||||
FairseqIncrementalDecoder,
|
|
||||||
register_model,
|
register_model,
|
||||||
register_model_architecture,
|
register_model_architecture,
|
||||||
)
|
)
|
||||||
from fairseq.models.transformer import Embedding
|
from fairseq.models.transformer import Embedding
|
||||||
from fairseq.modules import (
|
from fairseq.modules import PositionalEmbedding
|
||||||
AdaptiveSoftmax,
|
|
||||||
FairseqDropout,
|
|
||||||
LayerDropModuleList,
|
|
||||||
LayerNorm,
|
|
||||||
PositionalEmbedding,
|
|
||||||
SinusoidalPositionalEmbedding,
|
|
||||||
)
|
|
||||||
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
|
|
||||||
from torchscale.architecture.encoder import Encoder
|
from torchscale.architecture.encoder import Encoder
|
||||||
from torchscale.architecture.config import EncoderConfig, DecoderConfig
|
from torchscale.architecture.config import EncoderConfig, DecoderConfig
|
||||||
from .language_modeling import LMDecoder as MTDecoder
|
from .language_modeling import LMDecoder as MTDecoder
|
||||||
|
@ -164,18 +151,26 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
||||||
help="Use FP32 computations in MoE top2 gating function")
|
help="Use FP32 computations in MoE top2 gating function")
|
||||||
parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
|
parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
|
||||||
help="policy for second expert, options: all/sampling/random")
|
help="policy for second expert, options: all/sampling/random")
|
||||||
parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
|
parser.add_argument(
|
||||||
help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
|
'--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
|
||||||
|
help=(
|
||||||
|
"whether to normalize gate probs before or after dropping experts "
|
||||||
|
"for capacity and randomization"
|
||||||
|
)
|
||||||
|
)
|
||||||
parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
|
parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
|
||||||
help="MoE Expert FFN dimension")
|
help="MoE Expert FFN dimension")
|
||||||
parser.add_argument('--moe-top1-expert', default=False, action='store_true',
|
parser.add_argument('--moe-top1-expert', default=False, action='store_true',
|
||||||
help="Use top1 gate instead of top2")
|
help="Use top1 gate instead of top2")
|
||||||
parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
|
parser.add_argument(
|
||||||
help="Fraction of tokens as capacity during validation" + \
|
'--moe-eval-capacity-token-fraction', type=float, default=0.25,
|
||||||
"if set to negative, use same as training. range: (0.0, 1.0].")
|
help=(
|
||||||
|
"Fraction of tokens as capacity during validation"
|
||||||
|
"if set to negative, use same as training. range: (0.0, 1.0]."
|
||||||
|
)
|
||||||
|
)
|
||||||
parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
|
parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
|
||||||
help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
|
help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
|
||||||
|
|
||||||
parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
|
parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
|
||||||
help="Don't route padding tokens to any expert")
|
help="Don't route padding tokens to any expert")
|
||||||
parser.add_argument('--use-xmoe', default=False, action='store_true',
|
parser.add_argument('--use-xmoe', default=False, action='store_true',
|
||||||
|
@ -395,6 +390,7 @@ class MTEncoder(Encoder, FairseqEncoder):
|
||||||
def max_positions(self):
|
def max_positions(self):
|
||||||
return self.embed_positions.max_positions
|
return self.embed_positions.max_positions
|
||||||
|
|
||||||
|
|
||||||
@register_model_architecture("mt", "mt_base")
|
@register_model_architecture("mt", "mt_base")
|
||||||
def base_architecture(args):
|
def base_architecture(args):
|
||||||
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
||||||
|
|
|
@ -1,14 +1,11 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import math
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import torch
|
import torch
|
||||||
from infinibatch.iterators import CheckpointableIterator
|
from infinibatch.iterators import CheckpointableIterator
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
class BaseBatchGen(CheckpointableIterator):
|
class BaseBatchGen(CheckpointableIterator):
|
||||||
"""
|
"""
|
||||||
This is a base class for batch generators that use infinibatch
|
This is a base class for batch generators that use infinibatch
|
||||||
|
|
|
@ -1,13 +1,8 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import random
|
|
||||||
import itertools
|
import itertools
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
@ -135,7 +130,6 @@ class MLMLoader(BaseBatchGen):
|
||||||
|
|
||||||
return tokenized_lines
|
return tokenized_lines
|
||||||
|
|
||||||
|
|
||||||
def _batchify(self, lines):
|
def _batchify(self, lines):
|
||||||
|
|
||||||
if self.max_sentences is not None:
|
if self.max_sentences is not None:
|
||||||
|
@ -145,7 +139,8 @@ class MLMLoader(BaseBatchGen):
|
||||||
else:
|
else:
|
||||||
def dynamic_batch_size(sample):
|
def dynamic_batch_size(sample):
|
||||||
lengths = [len(x) for x in sample]
|
lengths = [len(x) for x in sample]
|
||||||
batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
|
batch_size = self.max_tokens // max(lengths)
|
||||||
|
batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
|
||||||
return max(1, batch_size)
|
return max(1, batch_size)
|
||||||
|
|
||||||
batches = iterators.BucketedReadaheadBatchIterator(
|
batches = iterators.BucketedReadaheadBatchIterator(
|
||||||
|
@ -207,7 +202,7 @@ class MLMLoader(BaseBatchGen):
|
||||||
|
|
||||||
def _mask_lm(self, _random, doc):
|
def _mask_lm(self, _random, doc):
|
||||||
def mask_tokens():
|
def mask_tokens():
|
||||||
return f"<mask>"
|
return "<mask>"
|
||||||
|
|
||||||
length = len(doc)
|
length = len(doc)
|
||||||
mask_tokens_num = int(length * self.args.mask_prob)
|
mask_tokens_num = int(length * self.args.mask_prob)
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import os
|
|
||||||
import gzip
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
|
from typing import Dict, Iterable, Optional
|
||||||
import collections
|
import collections
|
||||||
from infinibatch import iterators
|
from infinibatch import iterators
|
||||||
|
|
||||||
|
|
||||||
def apply_to_sample(f, sample):
|
def apply_to_sample(f, sample):
|
||||||
if hasattr(sample, "__len__") and len(sample) == 0:
|
if hasattr(sample, "__len__") and len(sample) == 0:
|
||||||
return {}
|
return {}
|
||||||
|
@ -34,6 +33,7 @@ def apply_to_sample(f, sample):
|
||||||
|
|
||||||
return _apply(sample)
|
return _apply(sample)
|
||||||
|
|
||||||
|
|
||||||
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
|
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
|
||||||
def __init__(self, iterable: Iterable):
|
def __init__(self, iterable: Iterable):
|
||||||
self._input_iterable = iterable
|
self._input_iterable = iterable
|
||||||
|
@ -44,7 +44,10 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
|
||||||
|
|
||||||
def setstate(self, checkpoint: Optional[Dict]):
|
def setstate(self, checkpoint: Optional[Dict]):
|
||||||
self._iterator = iter(self._input_iterable)
|
self._iterator = iter(self._input_iterable)
|
||||||
self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0
|
self._num_items_yielded = iterators._advance_iterator(
|
||||||
|
self._iterator,
|
||||||
|
checkpoint['num_items_yielded']
|
||||||
|
) if checkpoint is not None else 0
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
item = next(self._iterator)
|
item = next(self._iterator)
|
||||||
|
|
|
@ -10,16 +10,13 @@ import logging
|
||||||
import os
|
import os
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
import json
|
import json
|
||||||
from omegaconf import MISSING, II, OmegaConf
|
from omegaconf import MISSING, II
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from fairseq.data import Dictionary
|
from fairseq.data import Dictionary
|
||||||
from fairseq.tasks import FairseqTask, register_task
|
from fairseq.tasks import FairseqTask, register_task
|
||||||
from .data.mlm_loader import MLMLoader
|
from .data.mlm_loader import MLMLoader
|
||||||
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
|
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
|
||||||
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -27,6 +24,7 @@ logger = logging.getLogger(__name__)
|
||||||
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
|
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
|
||||||
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
|
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PretrainingConfig(FairseqDataclass):
|
class PretrainingConfig(FairseqDataclass):
|
||||||
data: str = field(
|
data: str = field(
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
|
||||||
from fairseq_cli.train import cli_main
|
from fairseq_cli.train import cli_main
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli_main()
|
cli_main()
|
|
@ -7,6 +7,7 @@ from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
|
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
|
||||||
def grad_exists(p):
|
def grad_exists(p):
|
||||||
|
|
|
@ -23,6 +23,7 @@ testcases = [
|
||||||
{"fsdp": True}
|
{"fsdp": True}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("args", testcases)
|
@pytest.mark.parametrize("args", testcases)
|
||||||
def test_decoder(args):
|
def test_decoder(args):
|
||||||
config = DecoderConfig(**args)
|
config = DecoderConfig(**args)
|
||||||
|
|
|
@ -23,6 +23,7 @@ testcases = [
|
||||||
{"fsdp": True}
|
{"fsdp": True}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("args", testcases)
|
@pytest.mark.parametrize("args", testcases)
|
||||||
def test_encoder(args):
|
def test_encoder(args):
|
||||||
config = EncoderConfig(**args)
|
config = EncoderConfig(**args)
|
||||||
|
|
|
@ -25,6 +25,7 @@ testcases = [
|
||||||
{"fsdp": True}
|
{"fsdp": True}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("args", testcases)
|
@pytest.mark.parametrize("args", testcases)
|
||||||
def test_decoder(args):
|
def test_decoder(args):
|
||||||
config = EncoderDecoderConfig(**args)
|
config = EncoderDecoderConfig(**args)
|
||||||
|
|
|
@ -355,7 +355,8 @@ def top2gating(
|
||||||
if has_tutel:
|
if has_tutel:
|
||||||
locations1_s = torch.sum(locations1 * mask1_, dim=1)
|
locations1_s = torch.sum(locations1 * mask1_, dim=1)
|
||||||
locations2_s = torch.sum(locations2 * mask2_, dim=1)
|
locations2_s = torch.sum(locations2 * mask2_, dim=1)
|
||||||
return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
|
return l_aux, metadata, capacity, num_experts, \
|
||||||
|
[indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
|
||||||
|
|
||||||
# Store the capacity location for each token
|
# Store the capacity location for each token
|
||||||
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user