crammed in DAdaptation (doesn't seem worth it) and ScheduleFree (forgot I wanted to try it weeks ago; seems promising), plus optimization wrapper cleanup, test trainer changes, etc.

This commit is contained in:
mrq 2024-05-09 20:28:20 -05:00
parent c6e0f905b5
commit 0d5d545a40
7 changed files with 256 additions and 116 deletions

View File

@@ -306,11 +306,14 @@ class Hyperparameters:
     optimizer: str = "Adamw"
     torch_optimizer: bool = False
     optimizer_params: dict = field(default_factory=lambda: {})
     learning_rate: float = 3.25e-4
-    scheduler_type: str = ""
+    scheduler: str = ""
+    scheduler_type: str = "" # deprecated
     scheduler_params: dict = field(default_factory=lambda: {})
+    torch_scheduler: bool = False

 @dataclass()
 class Evaluation:
@@ -337,7 +340,7 @@ class DeepSpeed:
         for k in cfg.hyperparameters.scheduler_params:
             scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]

-        if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
+        if cfg.hyperparameters.scheduler == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
             scheduler_params['total_num_steps'] = cfg.trainer.iterations

         ds_cfg = {
@@ -350,9 +353,9 @@
                 }
             } if not cfg.hyperparameters.torch_optimizer else None,
             "scheduler": {
-                "type": cfg.hyperparameters.scheduler_type,
+                "type": cfg.hyperparameters.scheduler,
                 "params": scheduler_params,
-            } if cfg.hyperparameters.scheduler_type != "" else None,
+            } if not cfg.hyperparameters.torch_scheduler else None,
             "gradient_clipping": cfg.hyperparameters.gradient_clipping,
             "fp16": {
                 "enabled": True,
@@ -544,15 +547,17 @@ class Inference:
 # should be renamed to optimizations
 @dataclass()
 class Optimizations:
-    bitsandbytes: bool = False
-    injects: bool = False
-    replace: bool = False
-    linear: bool = True
-    embedding: bool = True
-    bitnet: bool = False
-    fp8: bool = False
+    injects: bool = False # overwrites default torch classes (not recommended)
+    replace: bool = False # replaces modules in place with the optimized version (recommended)
+    linear: bool = True # inject/replace linear for BnB
+    embedding: bool = True # inject/replace embedding for BnB
+    optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
+    bitsandbytes: bool = False # use bitsandbytes
+    dadaptation: bool = True # use dadaptation optimizer
+    bitnet: bool = False # use bitnet
+    fp8: bool = False # use fp8

 @dataclass()
 class Config(_Config):
@@ -636,6 +641,17 @@ class Config(_Config):
         else:
             self.optimizations = Optimizations(**self.optimizations)
+
+        if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
+            self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
+            self.hyperparameters.scheduler_type = ""
+
+        # do not combine the two
+        if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
+            self.hyperparameters.scheduler = ""
+
+        if self.hyperparameters.scheduler == "":
+            self.hyperparameters.torch_scheduler = True

 # Preserves the old behavior
 class NaiveTokenizer:
     def get_vocab( self ):
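A condensed restatement of the scheduler resolution added above, as a standalone sketch; the function name and tuple return are illustrative only, the real logic lives in the config post-processing shown in the hunk:

def resolve_scheduler(scheduler: str, scheduler_type: str, dadaptation: bool):
    # the deprecated `scheduler_type` still works, but only as a fallback for `scheduler`
    if scheduler_type and not scheduler:
        scheduler = scheduler_type

    # do not combine ScheduleFree with D-Adaptation; both want to own the step size
    if scheduler == "schedulefree" and dadaptation:
        scheduler = ""

    # with no scheduler named, mark it torch-side so the DeepSpeed scheduler block is omitted
    torch_scheduler = scheduler == ""
    return scheduler, torch_scheduler

assert resolve_scheduler("", "WarmupDecayLR", dadaptation=False) == ("WarmupDecayLR", False)
assert resolve_scheduler("schedulefree", "", dadaptation=True) == ("", True)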

View File

@@ -379,6 +379,11 @@ class Dataset(_Dataset):
             path = random.choice(choices)
             if cfg.dataset.use_hdf5:
                 key = _get_hdf5_path(path)
+
+                if "audio" not in cfg.hdf5[key]:
+                    _logger.warning("MISSING AUDIO:", key)
+                    continue
+
                 qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
             else:
                 qnt = _load_quants(path)
@@ -763,15 +768,15 @@ def create_dataset_metadata( skip_existing=True ):
         name = str(dir)
         name = name.replace(root, "")

-        # yucky
         speaker_name = name
-        if "LbriTTS-R" in speaker_name:
-            speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")

         metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
         metadata_path.parents[0].mkdir(parents=True, exist_ok=True)

-        metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+        try:
+            metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+        except Exception as e:
+            metadata = {}

         if not os.path.isdir(f'{root}/{name}/'):
             return
@@ -872,8 +877,8 @@ def create_dataset_hdf5( skip_existing=True ):
         # yucky
         speaker_name = name
-        if "LbriTTS-R" in speaker_name:
-            speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
+        if "LibriTTS-R" in speaker_name:
+            speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")

         metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
         metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
@@ -899,10 +904,8 @@ def create_dataset_hdf5( skip_existing=True ):
             key = f'{type}/{speaker_name}/{id}'

-            """
             if skip_existing and key in hf:
                 continue
-            """

             group = hf.create_group(key) if key not in hf else hf[key]
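Related to the new missing-audio guard above, a quick way to audit an existing HDF5 dataset for such entries ahead of time. This is a hypothetical helper rather than anything in the commit; it assumes per-utterance groups that directly hold an `audio` dataset, per the `{type}/{speaker_name}/{id}` layout used here.

import h5py

def find_missing_audio( hdf5_path ):
    # returns the keys of per-utterance groups that lack an 'audio' dataset
    missing = []
    with h5py.File(hdf5_path, "r") as hf:
        def visit(name, node):
            if not isinstance(node, h5py.Group):
                return
            # leaf groups hold datasets directly; intermediate groups only hold subgroups
            holds_datasets = any(isinstance(child, h5py.Dataset) for child in node.values())
            if holds_datasets and "audio" not in node:
                missing.append(name)
        hf.visititems(visit)
    return missing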

View File

@@ -143,7 +143,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
 @cache
 def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
-    kwargs = dict(model_type="24khz",model_bitrate="8kbps",tag="latest")
+    kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
     """
     if not cfg.variable_sample_rate:
         # yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
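For reference, bumping model_type from "24khz" to "44khz" changes which pretrained Descript audio codec gets pulled. A hedged sketch of what those kwargs typically feed into, assuming the descript-audio-codec package's documented download/load helpers rather than the exact code in this file:

import dac

kwargs = dict(model_type="44khz", model_bitrate="8kbps", tag="latest")
model_path = dac.utils.download(**kwargs)  # fetch the pretrained 44khz weights
model = dac.DAC.load(model_path).to("cuda").eval()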

View File

@@ -45,11 +45,16 @@ def load_engines(training=True):
         if inferencing:
             model._cfg.training = False

-        if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
+        if cfg.optimizations.replace and cfg.optimizations.linear:
             model.model = ml.replace_linear( model.model )
+
+        if cfg.optimizations.replace and cfg.optimizations.embedding:
+            model.model = ml.replace_embedding( model.model )

         if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
             optimizer_class = None
+            scheduler_class = None
+
             params = {
                 "lr": cfg.hyperparameters.learning_rate,
             }
@@ -58,6 +63,10 @@ def load_engines(training=True):
                 params["eps"] = 1e-07
                 params["weight_decay"] = 0.01

+                # for dadaptation since it has Adam only
+                if ml.AdamW == ml.Adam:
+                    params["decouple"] = True
+
                 optimizer_class = ml.AdamW
             elif cfg.hyperparameters.optimizer.lower() == "sgd":
                 optimizer = ml.SGD
@@ -72,11 +81,27 @@ def load_engines(training=True):
                 raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')

             params.update(cfg.hyperparameters.optimizer_params)
             optimizer = optimizer_class(
                 [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
                 **params,
             )

+            if cfg.hyperparameters.scheduler.lower() == "schedulefree":
+                if cfg.hyperparameters.optimizer.lower() == "adamw":
+                    scheduler_class = ml.schedulefree.AdamWScheduleFree
+                elif cfg.hyperparameters.optimizer.lower() == "sgd":
+                    scheduler_class = ml.schedulefree.SGDScheduleFree
+                else:
+                    raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
+
+                optimizer = scheduler_class(
+                    [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+                    lr = params['lr']
+                )
+
         # set up our LR scheduler here

         if inferencing:
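One operational note on the ScheduleFree branch above, per the schedule_free README: the ScheduleFree classes replace the optimizer object outright (no separate LR scheduler is attached), and they expose train()/eval() modes that must be toggled alongside the model or evaluation and checkpointing will see the wrong weights. A minimal standalone sketch with a toy model and a made-up learning rate:

import torch
import schedulefree

model = torch.nn.Linear(16, 16)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=3.25e-4)

optimizer.train()  # switch the optimizer into training mode alongside model.train()
for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).square().mean()
    loss.backward()
    optimizer.step()

optimizer.eval()   # and back before evaluating or saving a checkpoint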

View File

@@ -365,7 +365,7 @@ def example_usage():
         'n_tokens': 1024,
         'd_model': 1024, # 256, # 1024, # 1536
         'n_heads': 16, # 4, # 16, # 24
-        'n_layers': 12, # 32
+        'n_layers': 4, # 32
         'n_experts': 1,

         'l_padding': 8 if cfg.optimizations.fp8 else 0,
@@ -381,16 +381,66 @@ def example_usage():
     """
     model = AR_NAR(**kwargs).to(device)
-    steps = 100
+    steps = 1000

-    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-    #optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
-    #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
+    optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
+    scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
+    learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
+
+    if cfg.optimizations.dadaptation:
+        # do not combine the two
+        if scheduler == "schedulefree":
+            scheduler = ""
+
+        learning_rate = 1.0
+
+    if optimizer == "prodigy":
+        if learning_rate is None:
+            learning_rate = 1.0
+
+        optimizer = ml.Prodigy
+    elif optimizer == "adagrad":
+        if learning_rate is None:
+            learning_rate = 1.0e-2
+
+        optimizer = ml.Adagrad
+    elif optimizer == "adamw":
+        if learning_rate is None:
+            learning_rate = 1.0e-4
+
+        optimizer = ml.AdamW
+    elif optimizer == "sdg":
+        if learning_rate is None:
+            learning_rate = 1.0e-4
+
+        optimizer = ml.SGD
+    else:
+        raise ValueError(f"Unrecognized optimizer: {optimizer}")
+
+    print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+
+    optimizer = optimizer(model.parameters(), lr=learning_rate)
+
+    if scheduler == "schedulefree":
+        if isinstance(optimizer, ml.AdamW):
+            scheduler = ml.schedulefree.AdamWScheduleFree
+        elif isinstance(optimizer, ml.SGD):
+            scheduler = ml.schedulefree.SGDScheduleFree
+        else:
+            scheduler = None
+
+        if scheduler is not None:
+            print("Scheduler:", scheduler)
+            optimizer = scheduler( model.parameters(), lr = learning_rate )
+
+    if cfg.optimizations.replace and cfg.optimizations.linear:
+        model = ml.replace_linear( model )
+
+    if cfg.optimizations.replace and cfg.optimizations.embedding:
+        model = ml.replace_embedding( model )

     engine = Engine(model=model, optimizer=optimizer)
-
-    if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
-        model.model = ml.replace_linear( model.model )

     torch.save( {
         'module': model.state_dict()
     }, "./data/test.pth" )

View File

@@ -16,6 +16,8 @@ from torch.nn.utils.rnn import pad_sequence
 from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision

+from ..utils import wrapper as ml
+
 from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample

 try:
@@ -191,48 +193,9 @@
         attn_output = self.o_proj(attn_output)

         return attn_output, attn_weights, past_key_value
-
-    LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
 except Exception as e:
     print("Error creating `LLamaXformersAttention`:", e)

-def replace_attention( model, impl, verbose=False ):
-    device = next(model.parameters()).device
-    dtype = next(model.parameters()).dtype
-
-    attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
-
-    if impl not in LLAMA_ATTENTIONS:
-        print(f"Attention '{imp} is not in LLAMA_ATTENTIONS'")
-        return model
-
-    klass = LLAMA_ATTENTIONS[impl]
-
-    for *parent, k in attentions:
-        name = '.'.join(parent)
-
-        # copy parameters
-        m = getattr( model.get_submodule(name), k )
-
-        if isinstance(m, klass):
-            continue
-
-        config = m.config
-        layer_idx = m.layer_idx
-
-        kwargs = dict(config=config, layer_idx=layer_idx)
-
-        # overwrite
-        setattr(
-            model.get_submodule(name), k,
-            klass( **kwargs ).to(device=device, dtype=dtype)
-        )
-
-        if verbose:
-            print(f"Replacing {name}.{k} to", klass)
-
-    return model
-
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -485,6 +448,14 @@ class Base(nn.Module):
         self.sep = nn.Parameter(torch.randn(d_model))

+        # ick, there has to be a better way
+        attention = self.config.attention if self.config is not None else None
+        use_xformers = False
+
+        if attention == "xformers":
+            use_xformers = True
+            attention = None
+
         if self.arch_type == "transformer":
             self.sin_emb = SinusoidalEmbedding(d_model)
             self.blocks = nn.ModuleList([TransformerBlock(
@@ -495,7 +466,7 @@
                 norm_type=self.norm_type,
                 n_levels=self.n_resp_levels,
             ) for _ in range(n_layers) ])
-        elif self.arch_type == "mistral" or self.arch_type == "mixtral":
+        elif self.arch_type in ["mistral", "mixtral"]:
             if n_experts <= 1:
                 self.model = MistralModel(MistralConfig(
                     vocab_size=n_resp_tokens,
@@ -509,7 +480,7 @@
                     hidden_act="gelu",
                     is_encoder_decoder=False,
                     is_decoder=True,
-                    attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+                    attn_implementation=attention,
                 ))
             else:
                 self.model = MixtralModel(MixtralConfig(
@@ -528,18 +499,10 @@
                     is_decoder=True,
                     num_local_experts=n_experts,
                     num_experts_per_tok=min(2, n_experts),
-                    attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+                    attn_implementation=attention,
                 ))
         elif self.arch_type == "llama":
             if n_experts <= 1:
-                # ick, there has to be a better way
-                attention = self.config.attention if self.config is not None else None # "flash_attention_2",
-                use_xformers = False
-
-                if attention == "xformers":
-                    use_xformers = True
-                    attention = None
-
                 self.model = LlamaModel(LlamaConfig(
                     vocab_size=n_resp_tokens,
                     hidden_size=d_model,
@@ -555,9 +518,6 @@
                     is_decoder=True,
                     attn_implementation=attention,
                 ))
-
-                if use_xformers:
-                    self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
             else:
                 self.model = MixtralModel(MixtralConfig(
                     vocab_size =n_resp_tokens,
@@ -575,9 +535,8 @@
                     is_decoder=True,
                     num_local_experts=n_experts,
                     num_experts_per_tok=min(2, n_experts),
-                    attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+                    attn_implementation=attention,
                 ))
-
         elif self.arch_type == "retnet":
             kwargs = dict(
                 vocab_size=n_resp_tokens,
@@ -589,9 +548,9 @@
                 dropout=p_dropout if training else 0.0,
                 checkpoint_activations=self.activation_checkpointing,
                 activation_fn="gelu",
-                use_layernorm=True, # self.version < 3,
-                use_biases=True, # self.version < 3,
-                use_glu=False, # self.version >= 3,
+                use_layernorm=self.version < 3,
+                use_biases=self.version < 3,
+                use_glu=self.version >= 3,

                 chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
                 recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
@@ -642,6 +601,9 @@
         else:
             raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

+        if use_xformers:
+            self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention )
+
         self.classifier = nn.Linear(d_model, n_resp_tokens)

         self.accuracy_metric = MulticlassAccuracy(
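To make the attention plumbing above concrete: every HF backend now receives the resolved attention value as attn_implementation, while "xformers" first builds the stock attention and swaps it afterwards through the generalized ml.replace_attention helper. A rough sketch of the llama path under those assumptions (toy sizes; the swap call is left commented since it needs the wrapper module and the xformers attention class from this file):

from transformers import LlamaConfig, LlamaModel
from transformers.models.llama.modeling_llama import LlamaAttention

attention = "xformers"                # e.g. what self.config.attention resolves to
use_xformers = attention == "xformers"
if use_xformers:
    attention = None                  # let HF build its default attention first

model = LlamaModel(LlamaConfig(
    vocab_size=1024, hidden_size=256, intermediate_size=1024,
    num_hidden_layers=2, num_attention_heads=4,
    attn_implementation=attention,
))

# then, as in the hunk above:
# if use_xformers:
#     model = ml.replace_attention( model, klass=LLamaXformersAttention, target=LlamaAttention )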

View File

@@ -9,6 +9,11 @@ from ..config import cfg
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear

+Adam = torch.optim.Adam
+AdamW = torch.optim.AdamW
+SGD = torch.optim.SGD
+Adagrad = torch.optim.Adagrad
+
 # https://github.com/kyegomez/BitNet
 if cfg.optimizations.bitnet:
     from bitnet import BitLinear
@@ -37,19 +42,20 @@ if cfg.optimizations.bitsandbytes:
             )).to(self.weight.dtype) )
     """

+    if cfg.optimizations.optimizers:
+        Adam = bnb.optim.Adam8bit
+        AdamW = bnb.optim.AdamW8bit
+        SGD = bnb.optim.SGD8bit
+        Adagrad = bnb.optim.Adagrad8bit
+
-if cfg.optimizations.bitsandbytes:
-    import bitsandbytes as bnb
-
-    Adam = bnb.optim.Adam8bit
-    AdamW = bnb.optim.AdamW8bit
-    SGD = bnb.optim.SGD8bit
-    Adagrad = bnb.optim.Adagrad8bit
-else:
-    Adam = torch.optim.Adam
-    AdamW = torch.optim.AdamW
-    SGD = torch.optim.SGD
-    Adagrad = torch.optim.Adagrad
+elif cfg.optimizations.dadaptation:
+    import dadaptation
+
+    if cfg.optimizations.optimizers:
+        Adam = dadaptation.DAdaptAdam
+        AdamW = dadaptation.DAdaptAdam
+        SGD = dadaptation.DAdaptSGD
+        AdaGrad = dadaptation.DAdaptAdaGrad

 # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
 @contextmanager
@@ -92,42 +98,112 @@ else:
     def autocast():
         yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)

-if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
-    torch.nn.Linear = Linear
-    torch.nn.Embedding = Embedding
-
-    torch.optim.Adam = Adam
-    torch.optim.AdamW = AdamW
-    torch.optim.SGD = SGD
+if cfg.optimizations.injects:
+    if cfg.optimizations.linear:
+        torch.nn.Linear = Linear
+
+    if cfg.optimizations.embedding:
+        torch.nn.Embedding = Embedding
+
+    if cfg.optimizations.optimizers:
+        torch.optim.Adam = Adam
+        torch.optim.AdamW = AdamW
+        torch.optim.SGD = SGD

 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
-def replace_linear( model, verbose=False ):
+# generalizing this would be super sugoi but the there's no catch all for arguments
+def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
     bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet

     device = next(model.parameters()).device
-    linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
-    klass = Linear
+    dtype = next(model.parameters()).dtype
+    modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]

-    for *parent, k in linears:
+    for *parent, k in modules:
         name = '.'.join(parent)

-        # copy parameters
         m = getattr( model.get_submodule(name), k )

         if isinstance(m, klass):
             continue

-        in_features = m.in_features
-        out_features = m.out_features
-        bias = m.bias is not None
-
-        kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias)
+        kwargs = dict(
+            in_features = m.in_features,
+            out_features = m.out_features,
+            bias = m.bias is not None,
+        ) if not bnb else dict(
+            input_features=m.in_features,
+            output_features=m.out_features,
+            bias=m.bias is not None,
+        )

         # overwrite
         setattr(
             model.get_submodule(name), k,
-            klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
+            klass( **kwargs ).to(device=device, dtype=dtype)
+        )
+
+        if verbose:
+            print(f"Replacing {name}.{k} to", klass)
+
+    return model
+
+def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
+    device = next(model.parameters()).device
+    dtype = next(model.parameters()).dtype
+
+    modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+    for *parent, k in modules:
+        name = '.'.join(parent)
+
+        m = getattr( model.get_submodule(name), k )
+
+        if isinstance(m, klass):
+            continue
+
+        kwargs = dict(
+            num_embeddings=m.num_embeddings,
+            embedding_dim=m.embedding_dim,
+            padding_idx=m.padding_idx,
+            max_norm=m.max_norm,
+            norm_type=m.norm_type,
+            scale_grad_by_freq=m.scale_grad_by_freq,
+            sparse=m.sparse,
+        )
+
+        # overwrite
+        setattr(
+            model.get_submodule(name), k,
+            klass( **kwargs ).to(device=device, dtype=dtype)
+        )
+
+        if verbose:
+            print(f"Replacing {name}.{k} to", klass)
+
+    return model
+
+# cannot feasibly do default arguments here sad
+def replace_attention( model, klass, target, verbose=False ):
+    device = next(model.parameters()).device
+    dtype = next(model.parameters()).dtype
+
+    modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+    for *parent, k in modules:
+        name = '.'.join(parent)
+
+        m = getattr( model.get_submodule(name), k )
+
+        if isinstance(m, klass):
+            continue
+
+        kwargs = dict(
+            config = m.config,
+            layer_idx = m.layer_idx,
+        )
+
+        # overwrite
+        setattr(
+            model.get_submodule(name), k,
+            klass( **kwargs ).to(device=device, dtype=dtype)
         )

         if verbose:
@@ -139,4 +215,12 @@ def replace_linear( model, verbose=False ):
 try:
     from prodigyopt import Prodigy
 except Exception as e:
+    print('Error while importing Prodigyopt:', str(e))
+    pass
+
+# https://github.com/facebookresearch/schedule_free/
+try:
+    import schedulefree
+except Exception as e:
+    print('Error while importing Schedule_Free:', str(e))
     pass
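For context on how these aliases get consumed: training code imports the wrapper and asks for ml.AdamW (or ml.schedulefree.AdamWScheduleFree once the guarded import above succeeds) without caring which backend is active. A hedged sketch, assuming the package root is vall_e (per the `from ..utils import wrapper as ml` import earlier) and that a default config resolves:

import torch
from vall_e.utils import wrapper as ml

model = torch.nn.Linear(16, 16)
# ml.AdamW resolves to torch.optim.AdamW, bnb.optim.AdamW8bit, or dadaptation.DAdaptAdam
# depending on cfg.optimizations.{bitsandbytes,dadaptation,optimizers}
optimizer = ml.AdamW(model.parameters(), lr=3.25e-4)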