crammed in DAdaptation (doesn't seem worth it) and ScheduleFree (forgot I wanted to try it weeks ago; seems promising), optimization wrapper cleanup, test trainer changes, etc.
parent c6e0f905b5
commit 0d5d545a40
@@ -306,11 +306,14 @@ class Hyperparameters:
 
 	optimizer: str = "Adamw"
 	torch_optimizer: bool = False
 
 	optimizer_params: dict = field(default_factory=lambda: {})
 	learning_rate: float = 3.25e-4
 
-	scheduler_type: str = ""
+	scheduler: str = ""
+	scheduler_type: str = "" # deprecated
 	scheduler_params: dict = field(default_factory=lambda: {})
+	torch_scheduler: bool = False
 
 @dataclass()
 class Evaluation:
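For reference, the new fields compose like this (a sketch; the field names come from the dataclass above, and "schedulefree" as a scheduler value comes from the engine wiring further down):

	hp = Hyperparameters(
		optimizer="Adamw",
		learning_rate=3.25e-4,
		scheduler="schedulefree", # replaces the now-deprecated scheduler_type
	)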
@@ -337,7 +340,7 @@ class DeepSpeed:
 		for k in cfg.hyperparameters.scheduler_params:
 			scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
 
-		if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
+		if cfg.hyperparameters.scheduler == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
 			scheduler_params['total_num_steps'] = cfg.trainer.iterations
 
 		ds_cfg = {
@@ -350,9 +353,9 @@ class DeepSpeed:
 				}
 			} if not cfg.hyperparameters.torch_optimizer else None,
 			"scheduler": {
-				"type": cfg.hyperparameters.scheduler_type,
+				"type": cfg.hyperparameters.scheduler,
 				"params": scheduler_params,
-			} if cfg.hyperparameters.scheduler_type != "" else None,
+			} if not cfg.hyperparameters.torch_scheduler else None,
 			"gradient_clipping": cfg.hyperparameters.gradient_clipping,
 			"fp16": {
 				"enabled": True,
@@ -544,15 +547,17 @@ class Inference:
 # should be renamed to optimizations
 @dataclass()
 class Optimizations:
-	bitsandbytes: bool = False
-	injects: bool = False
-	replace: bool = False
+	injects: bool = False # overwrites default torch classes (not recommended)
+	replace: bool = False # replaces modules in place with the optimized version (recommended)
 
-	linear: bool = True
-	embedding: bool = True
+	linear: bool = True # inject/replace linear for BnB
+	embedding: bool = True # inject/replace embedding for BnB
+	optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
 
-	bitnet: bool = False
-	fp8: bool = False
+	bitsandbytes: bool = False # use bitsandbytes
+	dadaptation: bool = True # use dadaptation optimizer
+	bitnet: bool = False # use bitnet
+	fp8: bool = False # use fp8
 
 @dataclass()
 class Config(_Config):
@@ -636,6 +641,17 @@ class Config(_Config):
 		else:
 			self.optimizations = Optimizations(**self.optimizations)
 
+		if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
+			self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
+			self.hyperparameters.scheduler_type = ""
+
+		# do not combine the two
+		if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
+			self.hyperparameters.scheduler = ""
+
+		if self.hyperparameters.scheduler == "":
+			self.hyperparameters.torch_scheduler = True
+
 # Preserves the old behavior
 class NaiveTokenizer:
 	def get_vocab( self ):
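The net effect of the shim above: configs that still set scheduler_type keep working, schedulefree is dropped when D-Adaptation is active (the two manage the step size in incompatible ways), and an empty scheduler falls back to the torch/DeepSpeed scheduler path. Equivalent standalone logic, for illustration only (_migrate_scheduler is a hypothetical name):

	def _migrate_scheduler(hp, optimizations):
		# move the deprecated field over; the new name wins if both are set
		if hp.scheduler_type and not hp.scheduler:
			hp.scheduler, hp.scheduler_type = hp.scheduler_type, ""
		# do not combine the two
		if hp.scheduler == "schedulefree" and optimizations.dadaptation:
			hp.scheduler = ""
		# no named scheduler: defer to the torch scheduler path
		if hp.scheduler == "":
			hp.torch_scheduler = True
		return hp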
@@ -379,6 +379,11 @@ class Dataset(_Dataset):
 		path = random.choice(choices)
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
+
+			if "audio" not in cfg.hdf5[key]:
+				_logger.warning(f"MISSING AUDIO: {key}")
+				continue
+
 			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
 			qnt = _load_quants(path)
@@ -763,15 +768,15 @@ def create_dataset_metadata( skip_existing=True ):
 		name = str(dir)
 		name = name.replace(root, "")
 
-		# yucky
 		speaker_name = name
-		if "LbriTTS-R" in speaker_name:
-			speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
 
 		metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
 		metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
 
-		metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+		try:
+			metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+		except Exception as e:
+			metadata = {}
 
 		if not os.path.isdir(f'{root}/{name}/'):
 			return
@@ -872,8 +877,8 @@ def create_dataset_hdf5( skip_existing=True ):
 
 		# yucky
 		speaker_name = name
-		if "LbriTTS-R" in speaker_name:
-			speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
+		if "LibriTTS-R" in speaker_name:
+			speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
 
 		metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
 		metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
@@ -899,10 +904,8 @@ def create_dataset_hdf5( skip_existing=True ):
 
 				key = f'{type}/{speaker_name}/{id}'
 
-				"""
 				if skip_existing and key in hf:
 					continue
-				"""
 
 				group = hf.create_group(key) if key not in hf else hf[key]
 
@@ -143,7 +143,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
 
 @cache
 def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
-	kwargs = dict(model_type="24khz",model_bitrate="8kbps",tag="latest")
+	kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
 	"""
 	if not cfg.variable_sample_rate:
 		# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'
@@ -45,11 +45,16 @@ def load_engines(training=True):
 	if inferencing:
 		model._cfg.training = False
 
-	if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
+	if cfg.optimizations.replace and cfg.optimizations.linear:
 		model.model = ml.replace_linear( model.model )
+
+	if cfg.optimizations.replace and cfg.optimizations.embedding:
+		model.model = ml.replace_embedding( model.model )
 
 	if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
 		optimizer_class = None
+		scheduler_class = None
 
 		params = {
 			"lr": cfg.hyperparameters.learning_rate,
 		}
@@ -58,6 +63,10 @@ def load_engines(training=True):
 			params["eps"] = 1e-07
 			params["weight_decay"] = 0.01
 
+			# for dadaptation since it has Adam only
+			if ml.AdamW == ml.Adam:
+				params["decouple"] = True
+
 			optimizer_class = ml.AdamW
 		elif cfg.hyperparameters.optimizer.lower() == "sgd":
 			optimizer = ml.SGD
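Context for the decouple flag: when D-Adaptation is enabled, the wrapper further down aliases both ml.Adam and ml.AdamW to dadaptation.DAdaptAdam, so ml.AdamW == ml.Adam detects that case. DAdaptAdam has no separate AdamW variant and instead takes decouple=True for AdamW-style decoupled weight decay. Outside the wrapper the same call would look like this (a sketch; model is assumed):

	import dadaptation

	# D-Adaptation estimates its own step size, so lr stays at 1.0
	optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0, weight_decay=0.01, decouple=True)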
@@ -72,11 +81,27 @@ def load_engines(training=True):
 				raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
 
 		params.update(cfg.hyperparameters.optimizer_params)
 
 		optimizer = optimizer_class(
 			[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
 			**params,
 		)
+
+		if cfg.hyperparameters.scheduler.lower() == "schedulefree":
+			if cfg.hyperparameters.optimizer.lower() == "adamw":
+				scheduler_class = ml.schedulefree.AdamWScheduleFree
+			elif cfg.hyperparameters.optimizer.lower() == "sgd":
+				scheduler_class = ml.schedulefree.SGDScheduleFree
+			else:
+				raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
+
+			optimizer = scheduler_class(
+				[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+				lr = params['lr']
+			)
+
 	# set up our LR scheduler here
 
 	if inferencing:
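One caveat with the ScheduleFree path: the "scheduler" is really a replacement optimizer (note it overwrites the AdamW/SGD instance built just above rather than wrapping it), and the reference implementation expects explicit mode switches around training and evaluation. A minimal sketch, with model assumed:

	import schedulefree

	optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=3.25e-4)

	optimizer.train() # before training steps
	# ... forward / backward / optimizer.step() ...
	optimizer.eval()  # before validation or checkpointing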
@@ -365,7 +365,7 @@ def example_usage():
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
-		'n_layers': 12, # 32
+		'n_layers': 4, # 32
 		'n_experts': 1,
 
 		'l_padding': 8 if cfg.optimizations.fp8 else 0,
@@ -381,16 +381,66 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 100
-	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-	#optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
-	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
+	steps = 1000
 
+	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
+	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
+	learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
+
+	if cfg.optimizations.dadaptation:
+		# do not combine the two
+		if scheduler == "schedulefree":
+			scheduler = ""
+
+		learning_rate = 1.0
+
+	if optimizer == "prodigy":
+		if learning_rate is None:
+			learning_rate = 1.0
+
+		optimizer = ml.Prodigy
+	elif optimizer == "adagrad":
+		if learning_rate is None:
+			learning_rate = 1.0e-2
+
+		optimizer = ml.Adagrad
+	elif optimizer == "adamw":
+		if learning_rate is None:
+			learning_rate = 1.0e-4
+
+		optimizer = ml.AdamW
+	elif optimizer == "sgd":
+		if learning_rate is None:
+			learning_rate = 1.0e-4
+
+		optimizer = ml.SGD
+	else:
+		raise ValueError(f"Unrecognized optimizer: {optimizer}")
+
+	print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+
+	optimizer = optimizer(model.parameters(), lr=learning_rate)
+
+	if scheduler == "schedulefree":
+		if isinstance(optimizer, ml.AdamW):
+			scheduler = ml.schedulefree.AdamWScheduleFree
+		elif isinstance(optimizer, ml.SGD):
+			scheduler = ml.schedulefree.SGDScheduleFree
+		else:
+			scheduler = None
+	else:
+		scheduler = None
+
+	if scheduler is not None:
+		print("Scheduler:", scheduler)
+		optimizer = scheduler( model.parameters(), lr = learning_rate )
+
+	if cfg.optimizations.replace and cfg.optimizations.linear:
+		model = ml.replace_linear( model )
+
+	if cfg.optimizations.replace and cfg.optimizations.embedding:
+		model = ml.replace_embedding( model )
 
 	engine = Engine(model=model, optimizer=optimizer)
 
-	if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
-		model.model = ml.replace_linear( model.model )
-
 	torch.save( {
 		'module': model.state_dict()
 	}, "./data/test.pth" )
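The learning_rate = 1.0 defaults above follow the convention of the self-tuning optimizers: D-Adaptation and Prodigy both estimate the step size on the fly and document lr=1.0 as the intended setting, e.g. (a sketch; model assumed):

	from prodigyopt import Prodigy

	optimizer = Prodigy(model.parameters(), lr=1.0)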
@@ -16,6 +16,8 @@ from torch.nn.utils.rnn import pad_sequence
 from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
+from ..utils import wrapper as ml
+
 from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
 
 try:
@@ -191,48 +193,9 @@ try:
 			attn_output = self.o_proj(attn_output)
 
 			return attn_output, attn_weights, past_key_value
-
-	LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
 except Exception as e:
 	print("Error creating `LLamaXformersAttention`:", e)
 
-def replace_attention( model, impl, verbose=False ):
-	device = next(model.parameters()).device
-	dtype = next(model.parameters()).dtype
-	attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
-
-	if impl not in LLAMA_ATTENTIONS:
-		print(f"Attention '{imp} is not in LLAMA_ATTENTIONS'")
-		return model
-
-	klass = LLAMA_ATTENTIONS[impl]
-
-	for *parent, k in attentions:
-		name = '.'.join(parent)
-
-		# copy parameters
-		m = getattr( model.get_submodule(name), k )
-
-		if isinstance(m, klass):
-			continue
-
-		config = m.config
-		layer_idx = m.layer_idx
-
-		kwargs = dict(config=config, layer_idx=layer_idx)
-
-		# overwrite
-		setattr(
-			model.get_submodule(name), k,
-			klass( **kwargs ).to(device=device, dtype=dtype)
-		)
-
-		if verbose:
-			print(f"Replacing {name}.{k} to", klass)
-
-	return model
-
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
 	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -485,6 +448,14 @@ class Base(nn.Module):
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
+		# ick, there has to be a better way
+		attention = self.config.attention if self.config is not None else None
+		use_xformers = False
+
+		if attention == "xformers":
+			use_xformers = True
+			attention = None
+
 		if self.arch_type == "transformer":
 			self.sin_emb = SinusoidalEmbedding(d_model)
 			self.blocks = nn.ModuleList([TransformerBlock(
@@ -495,7 +466,7 @@ class Base(nn.Module):
 				norm_type=self.norm_type,
 				n_levels=self.n_resp_levels,
 			) for _ in range(n_layers) ])
-		elif self.arch_type == "mistral" or self.arch_type == "mixtral":
+		elif self.arch_type in ["mistral", "mixtral"]:
 			if n_experts <= 1:
 				self.model = MistralModel(MistralConfig(
 					vocab_size=n_resp_tokens,
@@ -509,7 +480,7 @@ class Base(nn.Module):
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
-					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+					attn_implementation=attention,
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -528,18 +499,10 @@ class Base(nn.Module):
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
-					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+					attn_implementation=attention,
 				))
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
-				# ick, there has to be a better way
-				attention = self.config.attention if self.config is not None else None # "flash_attention_2",
-				use_xformers = False
-
-				if attention == "xformers":
-					use_xformers = True
-					attention = None
-
 				self.model = LlamaModel(LlamaConfig(
 					vocab_size=n_resp_tokens,
 					hidden_size=d_model,
@@ -555,9 +518,6 @@ class Base(nn.Module):
 					is_decoder=True,
 					attn_implementation=attention,
 				))
-
-				if use_xformers:
-					self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
 			else:
 				self.model = MixtralModel(MixtralConfig(
 					vocab_size =n_resp_tokens,
@@ -575,9 +535,8 @@ class Base(nn.Module):
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
-					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
+					attn_implementation=attention,
 				))
 
 		elif self.arch_type == "retnet":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
@@ -589,9 +548,9 @@ class Base(nn.Module):
 				dropout=p_dropout if training else 0.0,
 				checkpoint_activations=self.activation_checkpointing,
 				activation_fn="gelu",
-				use_layernorm=True, # self.version < 3,
-				use_biases=True, # self.version < 3,
-				use_glu=False, # self.version >= 3,
+				use_layernorm=self.version < 3,
+				use_biases=self.version < 3,
+				use_glu=self.version >= 3,
 
 				chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
 				recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
@@ -642,6 +601,9 @@ class Base(nn.Module):
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
+		if use_xformers:
+			self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention )
+
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
 
 		self.accuracy_metric = MulticlassAccuracy(
@@ -9,6 +9,11 @@ from ..config import cfg
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear
 
+Adam = torch.optim.Adam
+AdamW = torch.optim.AdamW
+SGD = torch.optim.SGD
+Adagrad = torch.optim.Adagrad
+
 # https://github.com/kyegomez/BitNet
 if cfg.optimizations.bitnet:
 	from bitnet import BitLinear
@@ -37,19 +42,20 @@ if cfg.optimizations.bitsandbytes:
 			)).to(self.weight.dtype) )
 	"""
 
-if cfg.optimizations.bitsandbytes:
-	import bitsandbytes as bnb
-
-	Adam = bnb.optim.Adam8bit
-	AdamW = bnb.optim.AdamW8bit
-	SGD = bnb.optim.SGD8bit
-	Adagrad = bnb.optim.Adagrad8bit
-else:
-	Adam = torch.optim.Adam
-	AdamW = torch.optim.AdamW
-	SGD = torch.optim.SGD
-	Adagrad = torch.optim.Adagrad
+	if cfg.optimizations.optimizers:
+		Adam = bnb.optim.Adam8bit
+		AdamW = bnb.optim.AdamW8bit
+		SGD = bnb.optim.SGD8bit
+		Adagrad = bnb.optim.Adagrad8bit
+
+elif cfg.optimizations.dadaptation:
+	import dadaptation
+
+	if cfg.optimizations.optimizers:
+		Adam = dadaptation.DAdaptAdam
+		AdamW = dadaptation.DAdaptAdam
+		SGD = dadaptation.DAdaptSGD
+		Adagrad = dadaptation.DAdaptAdaGrad
 
 # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
 @contextmanager
@@ -92,42 +98,112 @@ else:
 	def autocast():
 		yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
 
-if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
-	torch.nn.Linear = Linear
-	torch.nn.Embedding = Embedding
+if cfg.optimizations.injects:
+	if cfg.optimizations.linear:
+		torch.nn.Linear = Linear
 
-	torch.optim.Adam = Adam
-	torch.optim.AdamW = AdamW
-	torch.optim.SGD = SGD
+	if cfg.optimizations.embedding:
+		torch.nn.Embedding = Embedding
+
+	if cfg.optimizations.optimizers:
+		torch.optim.Adam = Adam
+		torch.optim.AdamW = AdamW
+		torch.optim.SGD = SGD
 
 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
-def replace_linear( model, verbose=False ):
+# generalizing this would be super sugoi but there's no catch-all for arguments
+def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
 	bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
 
 	device = next(model.parameters()).device
-	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
-	klass = Linear
+	dtype = next(model.parameters()).dtype
+	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
 
-	for *parent, k in linears:
+	for *parent, k in modules:
 		name = '.'.join(parent)
 
-		# copy parameters
 		m = getattr( model.get_submodule(name), k )
 
		if isinstance(m, klass):
 			continue
 
-		in_features = m.in_features
-		out_features = m.out_features
-		bias = m.bias is not None
-
-		kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias)
+		kwargs = dict(
+			in_features = m.in_features,
+			out_features = m.out_features,
+			bias = m.bias is not None,
+		) if not bnb else dict(
+			input_features=m.in_features,
+			output_features=m.out_features,
+			bias=m.bias is not None,
+		)
 
 		# overwrite
 		setattr(
 			model.get_submodule(name), k,
-			klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
+			klass( **kwargs ).to(device=device, dtype=dtype)
+		)
+
+		if verbose:
+			print(f"Replacing {name}.{k} to", klass)
+
+	return model
+
+def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+	for *parent, k in modules:
+		name = '.'.join(parent)
+
+		m = getattr( model.get_submodule(name), k )
+
+		if isinstance(m, klass):
+			continue
+
+		kwargs = dict(
+			num_embeddings=m.num_embeddings,
+			embedding_dim=m.embedding_dim,
+			padding_idx=m.padding_idx,
+			max_norm=m.max_norm,
+			norm_type=m.norm_type,
+			scale_grad_by_freq=m.scale_grad_by_freq,
+			sparse=m.sparse,
+		)
+
+		# overwrite
+		setattr(
+			model.get_submodule(name), k,
+			klass( **kwargs ).to(device=device, dtype=dtype)
+		)
+
+		if verbose:
+			print(f"Replacing {name}.{k} to", klass)
+
+	return model
+
+# cannot feasibly do default arguments here sad
+def replace_attention( model, klass, target, verbose=False ):
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+	for *parent, k in modules:
+		name = '.'.join(parent)
+
+		m = getattr( model.get_submodule(name), k )
+
+		if isinstance(m, klass):
+			continue
+
+		kwargs = dict(
+			config = m.config,
+			layer_idx = m.layer_idx,
+		)
+		# overwrite
+		setattr(
+			model.get_submodule(name), k,
+			klass( **kwargs ).to(device=device, dtype=dtype)
 		)
 
 		if verbose:
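All three helpers share the same find-and-swap pattern: walk named_modules() for instances of target, rebuild each one as klass from the old module's constructor arguments, and setattr it back onto its parent. This is also why replace is the recommended mode over injects: it converts an already-constructed model in place, while injecting only affects modules built after the torch classes were patched. Typical use mirrors the engine code above (a sketch; model assumed):

	model = ml.replace_linear( model )    # nn.Linear -> the configured Linear (e.g. bitsandbytes')
	model = ml.replace_embedding( model ) # nn.Embedding -> the configured Embedding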
@@ -139,4 +215,12 @@ def replace_linear( model, verbose=False ):
 try:
 	from prodigyopt import Prodigy
 except Exception as e:
+	print('Error while importing Prodigyopt:', str(e))
+	pass
+
+# https://github.com/facebookresearch/schedule_free/
+try:
+	import schedulefree
+except Exception as e:
+	print('Error while importing Schedule_Free:', str(e))
 	pass