crammed in DAdaptation (doesn't seem worth it) and ScheduleFree (forgot I wanted to try it weeks ago; seems promising), optimization wrapper cleanup, test trainer changes, etc.
This commit is contained in:
parent c6e0f905b5
commit 0d5d545a40
@@ -306,11 +306,14 @@ class Hyperparameters:
     optimizer: str = "Adamw"
+    torch_optimizer: bool = False
 
     optimizer_params: dict = field(default_factory=lambda: {})
     learning_rate: float = 3.25e-4
 
-    scheduler_type: str = ""
+    scheduler: str = ""
+    scheduler_type: str = "" # deprecated
     scheduler_params: dict = field(default_factory=lambda: {})
+    torch_scheduler: bool = False
 
 @dataclass()
 class Evaluation:
@@ -337,7 +340,7 @@ class DeepSpeed:
         for k in cfg.hyperparameters.scheduler_params:
             scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
 
-        if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
+        if cfg.hyperparameters.scheduler == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
             scheduler_params['total_num_steps'] = cfg.trainer.iterations
 
         ds_cfg = {
@@ -350,9 +353,9 @@ class DeepSpeed:
-            }
+            } if not cfg.hyperparameters.torch_optimizer else None,
             "scheduler": {
-                "type": cfg.hyperparameters.scheduler_type,
+                "type": cfg.hyperparameters.scheduler,
                 "params": scheduler_params,
-            } if cfg.hyperparameters.scheduler_type != "" else None,
+            } if not cfg.hyperparameters.torch_scheduler else None,
             "gradient_clipping": cfg.hyperparameters.gradient_clipping,
             "fp16": {
                 "enabled": True,
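
Note: the net effect of the two hunks above is that the DeepSpeed config only carries an optimizer/scheduler block when torch isn't supplying one. A minimal self-contained sketch of that gating (values are illustrative, SimpleNamespace stands in for cfg.hyperparameters):

# Sketch only: mirrors the ds_cfg gating above with a stand-in namespace.
from types import SimpleNamespace

hp = SimpleNamespace(
    optimizer="AdamW", torch_optimizer=True,          # optimizer is built torch-side
    scheduler="WarmupDecayLR", torch_scheduler=False, # scheduler stays with DeepSpeed
    scheduler_params={"warmup_num_steps": 250},       # illustrative params
    learning_rate=3.25e-4,
)

ds_optimizer = {"type": hp.optimizer, "params": {"lr": hp.learning_rate}} if not hp.torch_optimizer else None
ds_scheduler = {"type": hp.scheduler, "params": hp.scheduler_params} if not hp.torch_scheduler else None
print(ds_optimizer)   # None
print(ds_scheduler)   # {'type': 'WarmupDecayLR', 'params': {'warmup_num_steps': 250}}
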
@@ -544,15 +547,17 @@ class Inference:
 # should be renamed to optimizations
 @dataclass()
 class Optimizations:
-    bitsandbytes: bool = False
-    injects: bool = False
-    replace: bool = False
+    injects: bool = False # overwrites default torch classes (not recommended)
+    replace: bool = False # replaces modules in place with the optimized version (recommended)
 
-    linear: bool = True
-    embedding: bool = True
+    linear: bool = True # inject/replace linear for BnB
+    embedding: bool = True # inject/replace embedding for BnB
+    optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
 
-    bitnet: bool = False
-    fp8: bool = False
+    bitsandbytes: bool = False # use bitsandbytes
+    dadaptation: bool = True # use dadaptation optimizer
+    bitnet: bool = False # use bitnet
+    fp8: bool = False # use fp8
 
 @dataclass()
 class Config(_Config):
@@ -636,6 +641,17 @@ class Config(_Config):
         else:
             self.optimizations = Optimizations(**self.optimizations)
 
+        if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
+            self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
+            self.hyperparameters.scheduler_type = ""
+
+        # do not combine the two
+        if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
+            self.hyperparameters.scheduler = ""
+
+        if self.hyperparameters.scheduler == "":
+            self.hyperparameters.torch_scheduler = True
+
 # Preserves the old behavior
 class NaiveTokenizer:
     def get_vocab( self ):
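
Note: the post-init shim above keeps old configs working — a lingering scheduler_type is copied into scheduler, ScheduleFree is dropped when DAdaptation is active, and an empty scheduler forces the torch scheduler path. A self-contained sketch of that normalization (SimpleNamespace stands in for the real Hyperparameters instance):

# Sketch only: the normalization applied to a legacy config field.
from types import SimpleNamespace

hp = SimpleNamespace(scheduler="", scheduler_type="WarmupDecayLR", torch_scheduler=False)

if hp.scheduler_type and not hp.scheduler:
    hp.scheduler = hp.scheduler_type
    hp.scheduler_type = ""

if hp.scheduler == "":
    hp.torch_scheduler = True

assert hp.scheduler == "WarmupDecayLR" and hp.scheduler_type == ""
assert hp.torch_scheduler is False   # a real scheduler was named, so the torch fallback isn't forced
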
@@ -379,6 +379,11 @@ class Dataset(_Dataset):
             path = random.choice(choices)
             if cfg.dataset.use_hdf5:
                 key = _get_hdf5_path(path)
+
+                if "audio" not in cfg.hdf5[key]:
+                    _logger.warning("MISSING AUDIO:", key)
+                    continue
+
                 qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
             else:
                 qnt = _load_quants(path)
@@ -763,15 +768,15 @@ def create_dataset_metadata( skip_existing=True ):
         name = str(dir)
         name = name.replace(root, "")
 
         # yucky
         speaker_name = name
         if "LbriTTS-R" in speaker_name:
             speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
 
         metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
         metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
 
-        metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+        try:
+            metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+        except Exception as e:
+            metadata = {}
 
         if not os.path.isdir(f'{root}/{name}/'):
             return
@@ -872,8 +877,8 @@ def create_dataset_hdf5( skip_existing=True ):
 
         # yucky
         speaker_name = name
-        if "LbriTTS-R" in speaker_name:
-            speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
+        if "LibriTTS-R" in speaker_name:
+            speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
 
         metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
         metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
@@ -899,10 +904,8 @@ def create_dataset_hdf5( skip_existing=True ):
 
             key = f'{type}/{speaker_name}/{id}'
 
+            """
             if skip_existing and key in hf:
                 continue
+            """
 
             group = hf.create_group(key) if key not in hf else hf[key]
@@ -143,7 +143,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
 
 @cache
 def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
-    kwargs = dict(model_type="24khz",model_bitrate="8kbps",tag="latest")
+    kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
     """
     if not cfg.variable_sample_rate:
         # yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
@@ -45,11 +45,16 @@ def load_engines(training=True):
         if inferencing:
             model._cfg.training = False
 
-        if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
+        if cfg.optimizations.replace and cfg.optimizations.linear:
             model.model = ml.replace_linear( model.model )
 
+        if cfg.optimizations.replace and cfg.optimizations.embedding:
+            model.model = ml.replace_embedding( model.model )
+
         if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
             optimizer_class = None
             scheduler_class = None
 
             params = {
                 "lr": cfg.hyperparameters.learning_rate,
             }
@@ -58,6 +63,10 @@ def load_engines(training=True):
                 params["eps"] = 1e-07
                 params["weight_decay"] = 0.01
 
+                # for dadaptation since it has Adam only
+                if ml.AdamW == ml.Adam:
+                    params["decouple"] = True
+
                 optimizer_class = ml.AdamW
             elif cfg.hyperparameters.optimizer.lower() == "sgd":
                 optimizer = ml.SGD
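
Note: the `ml.AdamW == ml.Adam` test works because the wrapper below aliases both names to `dadaptation.DAdaptAdam` when DAdaptation is active, and DAdaptAdam only ships an Adam variant, so AdamW-style decay has to be requested through its `decouple` flag. A hedged sketch of the call this branch effectively makes (toy model, illustrative values):

# Sketch only: the direct dadaptation call this branch amounts to.
import torch
import dadaptation

model = torch.nn.Linear(8, 1)          # toy stand-in for the real model
optimizer = dadaptation.DAdaptAdam(
    model.parameters(),
    lr=1.0,                            # D-Adaptation treats lr as a multiplier on its own estimate
    eps=1e-07,                         # matches the params set above
    weight_decay=0.01,
    decouple=True,                     # AdamW-style decoupled weight decay
)
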
@@ -72,11 +81,27 @@ def load_engines(training=True):
                 raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
 
             params.update(cfg.hyperparameters.optimizer_params)
+
             optimizer = optimizer_class(
                 [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
                 **params,
             )
+
+            if cfg.hyperparameters.scheduler.lower() == "schedulefree":
+                if cfg.hyperparameters.optimizer.lower() == "adamw":
+                    scheduler_class = ml.schedulefree.AdamWScheduleFree
+                elif cfg.hyperparameters.optimizer.lower() == "sgd":
+                    scheduler_class = ml.schedulefree.SGDScheduleFree
+                else:
+                    raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
+
+                optimizer = scheduler_class(
+                    [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+                    lr = params['lr']
+                )
+
 
         # set up our LR scheduler here
 
         if inferencing:
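
Note: ScheduleFree keeps the schedule inside the optimizer, so the optimizer has to be flipped to train() before stepping and to eval() before evaluation or checkpointing; the hunk above only wires up construction. A minimal self-contained sketch of that usage contract (toy model and loop, not this repo's trainer):

# Sketch only: the ScheduleFree train()/eval() contract (per facebookresearch/schedule_free).
import torch
import schedulefree

model = torch.nn.Linear(8, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=3.25e-4)

optimizer.train()                      # switch to the training parameters
for _ in range(100):
    x = torch.randn(16, 8)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

optimizer.eval()                       # switch to the averaged parameters before eval/saving
torch.save(model.state_dict(), "checkpoint.pth")
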
@@ -365,7 +365,7 @@ def example_usage():
         'n_tokens': 1024,
         'd_model': 1024, # 256, # 1024, # 1536
         'n_heads': 16, # 4, # 16, # 24
-        'n_layers': 12, # 32
+        'n_layers': 4, # 32
         'n_experts': 1,
 
         'l_padding': 8 if cfg.optimizations.fp8 else 0,
@@ -381,16 +381,66 @@ def example_usage():
     """
 
     model = AR_NAR(**kwargs).to(device)
-    steps = 100
-    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-    #optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
-    #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
+    steps = 1000
+
+    optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
+    scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
+    learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
+
+    if cfg.optimizations.dadaptation:
+        # do not combine the two
+        if scheduler == "schedulefree":
+            scheduler = ""
+
+        learning_rate = 1.0
+
+    if optimizer == "prodigy":
+        if learning_rate is None:
+            learning_rate = 1.0
+
+        optimizer = ml.Prodigy
+    elif optimizer == "adagrad":
+        if learning_rate is None:
+            learning_rate = 1.0e-2
+
+        optimizer = ml.Adagrad
+    elif optimizer == "adamw":
+        if learning_rate is None:
+            learning_rate = 1.0e-4
+
+        optimizer = ml.AdamW
+    elif optimizer == "sdg":
+        if learning_rate is None:
+            learning_rate = 1.0e-4
+
+        optimizer = ml.SGD
+    else:
+        raise ValueError(f"Unrecognized optimizer: {optimizer}")
+
+    print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+
+    optimizer = optimizer(model.parameters(), lr=learning_rate)
+
+    if scheduler == "schedulefree":
+        if isinstance(optimizer, ml.AdamW):
+            scheduler = ml.schedulefree.AdamWScheduleFree
+        elif isinstance(optimizer, ml.SGD):
+            scheduler = ml.schedulefree.SGDScheduleFree
+        else:
+            scheduler = None
+
+        if scheduler is not None:
+            print("Scheduler:", scheduler)
+            optimizer = scheduler( model.parameters(), lr = learning_rate )
+
+    if cfg.optimizations.replace and cfg.optimizations.linear:
+        model = ml.replace_linear( model )
+
+    if cfg.optimizations.replace and cfg.optimizations.embedding:
+        model = ml.replace_embedding( model )
 
     engine = Engine(model=model, optimizer=optimizer)
 
-    if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
-        model.model = ml.replace_linear( model.model )
-
     torch.save( {
         'module': model.state_dict()
     }, "./data/test.pth" )
@@ -16,6 +16,8 @@ from torch.nn.utils.rnn import pad_sequence
 from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
+from ..utils import wrapper as ml
+
 from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
 
 try:
@@ -191,48 +193,9 @@ try:
             attn_output = self.o_proj(attn_output)
 
             return attn_output, attn_weights, past_key_value
 
-    LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
-
 except Exception as e:
     print("Error creating `LLamaXformersAttention`:", e)
 
-def replace_attention( model, impl, verbose=False ):
-    device = next(model.parameters()).device
-    dtype = next(model.parameters()).dtype
-    attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
-
-    if impl not in LLAMA_ATTENTIONS:
-        print(f"Attention '{imp} is not in LLAMA_ATTENTIONS'")
-        return model
-
-    klass = LLAMA_ATTENTIONS[impl]
-
-    for *parent, k in attentions:
-        name = '.'.join(parent)
-
-        # copy parameters
-        m = getattr( model.get_submodule(name), k )
-
-        if isinstance(m, klass):
-            continue
-
-        config = m.config
-        layer_idx = m.layer_idx
-
-        kwargs = dict(config=config, layer_idx=layer_idx)
-
-        # overwrite
-        setattr(
-            model.get_submodule(name), k,
-            klass( **kwargs ).to(device=device, dtype=dtype)
-        )
-
-        if verbose:
-            print(f"Replacing {name}.{k} to", klass)
-
-    return model
-
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)
@@ -485,6 +448,14 @@ class Base(nn.Module):
 
         self.sep = nn.Parameter(torch.randn(d_model))
 
+        # ick, there has to be a better way
+        attention = self.config.attention if self.config is not None else None
+        use_xformers = False
+
+        if attention == "xformers":
+            use_xformers = True
+            attention = None
+
         if self.arch_type == "transformer":
             self.sin_emb = SinusoidalEmbedding(d_model)
             self.blocks = nn.ModuleList([TransformerBlock(
@@ -495,7 +466,7 @@ class Base(nn.Module):
                 norm_type=self.norm_type,
                 n_levels=self.n_resp_levels,
             ) for _ in range(n_layers) ])
-        elif self.arch_type == "mistral" or self.arch_type == "mixtral":
+        elif self.arch_type in ["mistral", "mixtral"]:
             if n_experts <= 1:
                 self.model = MistralModel(MistralConfig(
                     vocab_size=n_resp_tokens,
@@ -509,7 +480,7 @@ class Base(nn.Module):
                     hidden_act="gelu",
                     is_encoder_decoder=False,
                     is_decoder=True,
-                    attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+                    attn_implementation=attention,
                 ))
             else:
                 self.model = MixtralModel(MixtralConfig(
@@ -528,18 +499,10 @@
                     is_decoder=True,
                     num_local_experts=n_experts,
                     num_experts_per_tok=min(2, n_experts),
-                    attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+                    attn_implementation=attention,
                 ))
         elif self.arch_type == "llama":
             if n_experts <= 1:
-                # ick, there has to be a better way
-                attention = self.config.attention if self.config is not None else None # "flash_attention_2",
-                use_xformers = False
-
-                if attention == "xformers":
-                    use_xformers = True
-                    attention = None
-
                 self.model = LlamaModel(LlamaConfig(
                     vocab_size=n_resp_tokens,
                     hidden_size=d_model,
@@ -555,9 +518,6 @@
                     is_decoder=True,
                     attn_implementation=attention,
                 ))
-
-                if use_xformers:
-                    self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
             else:
                 self.model = MixtralModel(MixtralConfig(
                     vocab_size =n_resp_tokens,
@@ -575,9 +535,8 @@
                     is_decoder=True,
                     num_local_experts=n_experts,
                     num_experts_per_tok=min(2, n_experts),
-                    attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+                    attn_implementation=attention,
                 ))
-
         elif self.arch_type == "retnet":
             kwargs = dict(
                 vocab_size=n_resp_tokens,
@@ -589,9 +548,9 @@
                 dropout=p_dropout if training else 0.0,
                 checkpoint_activations=self.activation_checkpointing,
                 activation_fn="gelu",
-                use_layernorm=True, # self.version < 3,
-                use_biases=True, # self.version < 3,
-                use_glu=False, # self.version >= 3,
+                use_layernorm=self.version < 3,
+                use_biases=self.version < 3,
+                use_glu=self.version >= 3,
 
                 chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
                 recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
@@ -642,6 +601,9 @@
         else:
             raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
+        if use_xformers:
+            self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention )
+
         self.classifier = nn.Linear(d_model, n_resp_tokens)
 
         self.accuracy_metric = MulticlassAccuracy(
@@ -9,6 +9,11 @@ from ..config import cfg
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear
 
+Adam = torch.optim.Adam
+AdamW = torch.optim.AdamW
+SGD = torch.optim.SGD
+Adagrad = torch.optim.Adagrad
+
 # https://github.com/kyegomez/BitNet
 if cfg.optimizations.bitnet:
     from bitnet import BitLinear
@@ -37,19 +42,20 @@ if cfg.optimizations.bitsandbytes:
         )).to(self.weight.dtype) )
         """
 
+    if cfg.optimizations.optimizers:
+        Adam = bnb.optim.Adam8bit
+        AdamW = bnb.optim.AdamW8bit
+        SGD = bnb.optim.SGD8bit
+        Adagrad = bnb.optim.Adagrad8bit
 
-if cfg.optimizations.bitsandbytes:
-    import bitsandbytes as bnb
+elif cfg.optimizations.dadaptation:
+    import dadaptation
 
-    Adam = bnb.optim.Adam8bit
-    AdamW = bnb.optim.AdamW8bit
-    SGD = bnb.optim.SGD8bit
-    Adagrad = bnb.optim.Adagrad8bit
-else:
-    Adam = torch.optim.Adam
-    AdamW = torch.optim.AdamW
-    SGD = torch.optim.SGD
-    Adagrad = torch.optim.Adagrad
+    if cfg.optimizations.optimizers:
+        Adam = dadaptation.DAdaptAdam
+        AdamW = dadaptation.DAdaptAdam
+        SGD = dadaptation.DAdaptSGD
+        AdaGrad = dadaptation.DAdaptAdaGrad
 
 # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
 @contextmanager
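
Note: the payoff of the module-level aliases is that callers never branch on the backend; whatever cfg.optimizations selected at import time (torch, bitsandbytes 8-bit, or DAdaptation) is what `ml.AdamW` resolves to. A hedged sketch from the caller's side (import path assumed from the `from ..utils import wrapper as ml` usage elsewhere in this commit):

# Sketch only: backend-agnostic optimizer construction through the aliases.
import torch
from vall_e.utils import wrapper as ml   # import path assumed

model = torch.nn.Linear(8, 1)

# ml.AdamW may resolve to torch.optim.AdamW, bnb.optim.AdamW8bit, or dadaptation.DAdaptAdam.
optimizer = ml.AdamW(model.parameters(), lr=3.25e-4)

if ml.AdamW == ml.Adam:                  # only true for the DAdaptation backend
    print("DAdaptation active: lr acts as a multiplier; decouple=True gives AdamW behaviour")
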
@@ -92,42 +98,112 @@ else:
     def autocast():
         yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
 
-if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
-    torch.nn.Linear = Linear
-    torch.nn.Embedding = Embedding
+if cfg.optimizations.injects:
+    if cfg.optimizations.linear:
+        torch.nn.Linear = Linear
+
+    if cfg.optimizations.embedding:
+        torch.nn.Embedding = Embedding
 
-    torch.optim.Adam = Adam
-    torch.optim.AdamW = AdamW
-    torch.optim.SGD = SGD
+    if cfg.optimizations.optimizers:
+        torch.optim.Adam = Adam
+        torch.optim.AdamW = AdamW
+        torch.optim.SGD = SGD
 
-# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
-def replace_linear( model, verbose=False ):
+# generalizing this would be super sugoi but the there's no catch all for arguments
+def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
     bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
 
     device = next(model.parameters()).device
-    linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
-    klass = Linear
+    dtype = next(model.parameters()).dtype
+    modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
 
-    for *parent, k in linears:
+    for *parent, k in modules:
         name = '.'.join(parent)
 
         # copy parameters
         m = getattr( model.get_submodule(name), k )
 
         if isinstance(m, klass):
            continue
 
-        in_features = m.in_features
-        out_features = m.out_features
-        bias = m.bias is not None
-
-        kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias)
+        kwargs = dict(
+            in_features = m.in_features,
+            out_features = m.out_features,
+            bias = m.bias is not None,
+        ) if not bnb else dict(
+            input_features=m.in_features,
+            output_features=m.out_features,
+            bias=m.bias is not None,
+        )
 
         # overwrite
         setattr(
             model.get_submodule(name), k,
-            klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
+            klass( **kwargs ).to(device=device, dtype=dtype)
         )
 
         if verbose:
            print(f"Replacing {name}.{k} to", klass)
 
     return model
 
+def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
+    device = next(model.parameters()).device
+    dtype = next(model.parameters()).dtype
+    modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+    for *parent, k in modules:
+        name = '.'.join(parent)
+
+        m = getattr( model.get_submodule(name), k )
+
+        if isinstance(m, klass):
+            continue
+
+        kwargs = dict(
+            num_embeddings=m.num_embeddings,
+            embedding_dim=m.embedding_dim,
+            padding_idx=m.padding_idx,
+            max_norm=m.max_norm,
+            norm_type=m.norm_type,
+            scale_grad_by_freq=m.scale_grad_by_freq,
+            sparse=m.sparse,
+        )
+
+        # overwrite
+        setattr(
+            model.get_submodule(name), k,
+            klass( **kwargs ).to(device=device, dtype=dtype)
+        )
+
+        if verbose:
+            print(f"Replacing {name}.{k} to", klass)
+
+    return model
+
+# cannot feasibly do default arguments here sad
+def replace_attention( model, klass, target, verbose=False ):
+    device = next(model.parameters()).device
+    dtype = next(model.parameters()).dtype
+    modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+    for *parent, k in modules:
+        name = '.'.join(parent)
+
+        m = getattr( model.get_submodule(name), k )
+
+        if isinstance(m, klass):
+            continue
+
+        kwargs = dict(
+            config = m.config,
+            layer_idx = m.layer_idx,
+        )
+        # overwrite
+        setattr(
+            model.get_submodule(name), k,
+            klass( **kwargs ).to(device=device, dtype=dtype)
+        )
+
+        if verbose:
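
Note: with `klass`/`target` now parameters, the same walk-and-swap loop serves Linear, Embedding and attention replacement alike: it reads the old module's shape, builds the new class, and setattr()s it onto the parent at the original device/dtype. A hedged usage sketch with a hypothetical drop-in class (import path assumed):

# Sketch only: grafting a hypothetical Linear subclass into every nn.Linear slot.
import torch
from vall_e.utils import wrapper as ml   # import path assumed

class NoisyLinear(torch.nn.Linear):      # hypothetical drop-in; accepts the same in/out/bias kwargs
    def forward(self, x):
        out = super().forward(x)
        return out + 1e-3 * torch.randn_like(out)

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
)

# Every torch.nn.Linear is rebuilt as NoisyLinear with matching in/out features and bias,
# then moved onto the old module's device/dtype before being swapped in.
model = ml.replace_linear(model, klass=NoisyLinear, target=torch.nn.Linear, verbose=True)
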
@@ -139,4 +215,12 @@ def replace_linear( model, verbose=False ):
 try:
     from prodigyopt import Prodigy
 except Exception as e:
     print('Error while importing Prodigyopt:', str(e))
+    pass
+
+# https://github.com/facebookresearch/schedule_free/
+try:
+    import schedulefree
+except Exception as e:
+    print('Error while importing Schedule_Free:', str(e))
+    pass