Remove inheritance from object
This commit is contained in:
parent
2b101355d7
commit
7bfdad13f8
|
@ -62,7 +62,7 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class WeightIterator(object):
|
class WeightIterator:
|
||||||
def __init__(self, weights, seed):
|
def __init__(self, weights, seed):
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
|
|
||||||
class EncoderConfig(object):
|
class EncoderConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
|
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
|
||||||
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
|
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
|
||||||
|
@ -71,7 +71,7 @@ class EncoderConfig(object):
|
||||||
self.__dict__[hp] = getattr(args, hp, None)
|
self.__dict__[hp] = getattr(args, hp, None)
|
||||||
|
|
||||||
|
|
||||||
class DecoderConfig(object):
|
class DecoderConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
|
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
|
||||||
self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
|
self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
|
||||||
|
@ -135,7 +135,7 @@ class DecoderConfig(object):
|
||||||
self.__dict__[hp] = getattr(args, hp, None)
|
self.__dict__[hp] = getattr(args, hp, None)
|
||||||
|
|
||||||
|
|
||||||
class EncoderDecoderConfig(object):
|
class EncoderDecoderConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
|
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
|
||||||
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
|
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
|
||||||
|
|
|
@ -13,7 +13,7 @@ except ModuleNotFoundError:
|
||||||
from .xmoe.global_groups import get_moe_group
|
from .xmoe.global_groups import get_moe_group
|
||||||
|
|
||||||
|
|
||||||
class set_torch_seed(object):
|
class set_torch_seed:
|
||||||
def __init__(self, seed):
|
def __init__(self, seed):
|
||||||
assert isinstance(seed, int)
|
assert isinstance(seed, int)
|
||||||
self.rng_state = self.get_rng_state()
|
self.rng_state = self.get_rng_state()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user