crammed in DAdaptation (doesn't seem worth it) and ScheduleFree (forgot I wanted to weeks ago, seems promising), optimization wrapper cleanup, test trainer changes, etc.

This commit is contained in:
mrq 2024-05-09 20:28:20 -05:00
parent c6e0f905b5
commit 0d5d545a40
7 changed files with 256 additions and 116 deletions

View File

@ -306,11 +306,14 @@ class Hyperparameters:
optimizer: str = "Adamw"
torch_optimizer: bool = False
optimizer_params: dict = field(default_factory=lambda: {})
learning_rate: float = 3.25e-4
scheduler_type: str = ""
scheduler: str = ""
scheduler_type: str = "" # deprecated
scheduler_params: dict = field(default_factory=lambda: {})
torch_scheduler: bool = False
@dataclass()
class Evaluation:
@ -337,7 +340,7 @@ class DeepSpeed:
for k in cfg.hyperparameters.scheduler_params:
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
if cfg.hyperparameters.scheduler == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
scheduler_params['total_num_steps'] = cfg.trainer.iterations
ds_cfg = {
@ -350,9 +353,9 @@ class DeepSpeed:
}
} if not cfg.hyperparameters.torch_optimizer else None,
"scheduler": {
"type": cfg.hyperparameters.scheduler_type,
"type": cfg.hyperparameters.scheduler,
"params": scheduler_params,
} if cfg.hyperparameters.scheduler_type != "" else None,
} if not cfg.hyperparameters.torch_scheduler else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": {
"enabled": True,
@ -544,15 +547,17 @@ class Inference:
# should be renamed to optimizations
@dataclass()
class Optimizations:
bitsandbytes: bool = False
injects: bool = False
replace: bool = False
injects: bool = False # overwrites default torch classes (not recommended)
replace: bool = False # replaces modules in place with the optimized version (recommended)
linear: bool = True
embedding: bool = True
linear: bool = True # inject/replace linear for BnB
embedding: bool = True # inject/replace embedding for BnB
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
bitnet: bool = False
fp8: bool = False
bitsandbytes: bool = False # use bitsandbytes
dadaptation: bool = True # use dadaptation optimizer
bitnet: bool = False # use bitnet
fp8: bool = False # use fp8
@dataclass()
class Config(_Config):
@ -636,6 +641,17 @@ class Config(_Config):
else:
self.optimizations = Optimizations(**self.optimizations)
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
self.hyperparameters.scheduler_type = ""
# do not combine the two
if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
self.hyperparameters.scheduler = ""
if self.hyperparameters.scheduler == "":
self.hyperparameters.torch_scheduler = True
# Preserves the old behavior
class NaiveTokenizer:
def get_vocab( self ):

View File

@ -379,6 +379,11 @@ class Dataset(_Dataset):
path = random.choice(choices)
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if "audio" not in cfg.hdf5[key]:
_logger.warning("MISSING AUDIO:", key)
continue
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
qnt = _load_quants(path)
@ -763,15 +768,15 @@ def create_dataset_metadata( skip_existing=True ):
name = str(dir)
name = name.replace(root, "")
# yucky
speaker_name = name
if "LbriTTS-R" in speaker_name:
speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
try:
metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
except Exception as e:
metadata = {}
if not os.path.isdir(f'{root}/{name}/'):
return
@ -872,8 +877,8 @@ def create_dataset_hdf5( skip_existing=True ):
# yucky
speaker_name = name
if "LbriTTS-R" in speaker_name:
speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox")
if "LibriTTS-R" in speaker_name:
speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
metadata_path.parents[0].mkdir(parents=True, exist_ok=True)
@ -899,10 +904,8 @@ def create_dataset_hdf5( skip_existing=True ):
key = f'{type}/{speaker_name}/{id}'
"""
if skip_existing and key in hf:
continue
"""
group = hf.create_group(key) if key not in hf else hf[key]

View File

@ -143,7 +143,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
@cache
def _load_dac_model(device="cuda", levels=cfg.model.max_levels):
kwargs = dict(model_type="24khz",model_bitrate="8kbps",tag="latest")
kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest")
"""
if not cfg.variable_sample_rate:
# yes there's a better way, something like f'{cfg.sample.rate//1000}hz'

View File

@ -45,11 +45,16 @@ def load_engines(training=True):
if inferencing:
model._cfg.training = False
if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model.model = ml.replace_embedding( model.model )
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
optimizer_class = None
scheduler_class = None
params = {
"lr": cfg.hyperparameters.learning_rate,
}
@ -58,6 +63,10 @@ def load_engines(training=True):
params["eps"] = 1e-07
params["weight_decay"] = 0.01
# for dadaptation since it has Adam only
if ml.AdamW == ml.Adam:
params["decouple"] = True
optimizer_class = ml.AdamW
elif cfg.hyperparameters.optimizer.lower() == "sgd":
optimizer = ml.SGD
@ -72,11 +81,27 @@ def load_engines(training=True):
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
**params,
)
if cfg.hyperparameters.scheduler.lower() == "schedulefree":
if cfg.hyperparameters.optimizer.lower() == "adamw":
scheduler_class = ml.schedulefree.AdamWScheduleFree
elif cfg.hyperparameters.optimizer.lower() == "sgd":
scheduler_class = ml.schedulefree.SGDScheduleFree
else:
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
optimizer = scheduler_class(
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
lr = params['lr']
)
# set up our LR scheduler here
if inferencing:

View File

@ -365,7 +365,7 @@ def example_usage():
'n_tokens': 1024,
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_layers': 4, # 32
'n_experts': 1,
'l_padding': 8 if cfg.optimizations.fp8 else 0,
@ -381,16 +381,66 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
steps = 100
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
#optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2)
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
steps = 1000
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
if cfg.optimizations.dadaptation:
# do not combine the two
if scheduler == "schedulefree":
scheduler = ""
learning_rate = 1.0
if optimizer == "prodigy":
if learning_rate is None:
learning_rate = 1.0
optimizer = ml.Prodigy
elif optimizer == "adagrad":
if learning_rate is None:
learning_rate = 1.0e-2
optimizer = ml.Adagrad
elif optimizer == "adamw":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.AdamW
elif optimizer == "sdg":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.SGD
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
optimizer = optimizer(model.parameters(), lr=learning_rate)
if scheduler == "schedulefree":
if isinstance(optimizer, ml.AdamW):
scheduler = ml.schedulefree.AdamWScheduleFree
elif isinstance(optimizer, ml.SGD):
scheduler = ml.schedulefree.SGDScheduleFree
else:
scheduler = None
if scheduler is not None:
print("Scheduler:", scheduler)
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
model = ml.replace_linear( model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model = ml.replace_embedding( model )
engine = Engine(model=model, optimizer=optimizer)
if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8):
model.model = ml.replace_linear( model.model )
torch.save( {
'module': model.state_dict()
}, "./data/test.pth" )

View File

@ -16,6 +16,8 @@ from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from ..utils import wrapper as ml
from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
try:
@ -191,48 +193,9 @@ try:
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention
except Exception as e:
print("Error creating `LLamaXformersAttention`:", e)
def replace_attention( model, impl, verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)]
if impl not in LLAMA_ATTENTIONS:
print(f"Attention '{imp} is not in LLAMA_ATTENTIONS'")
return model
klass = LLAMA_ATTENTIONS[impl]
for *parent, k in attentions:
name = '.'.join(parent)
# copy parameters
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
config = m.config
layer_idx = m.layer_idx
kwargs = dict(config=config, layer_idx=layer_idx)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@ -485,6 +448,14 @@ class Base(nn.Module):
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
attention = self.config.attention if self.config is not None else None
use_xformers = False
if attention == "xformers":
use_xformers = True
attention = None
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
self.blocks = nn.ModuleList([TransformerBlock(
@ -495,7 +466,7 @@ class Base(nn.Module):
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type == "mistral" or self.arch_type == "mixtral":
elif self.arch_type in ["mistral", "mixtral"]:
if n_experts <= 1:
self.model = MistralModel(MistralConfig(
vocab_size=n_resp_tokens,
@ -509,7 +480,7 @@ class Base(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
attn_implementation=attention,
))
else:
self.model = MixtralModel(MixtralConfig(
@ -528,18 +499,10 @@ class Base(nn.Module):
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
attn_implementation=attention,
))
elif self.arch_type == "llama":
if n_experts <= 1:
# ick, there has to be a better way
attention = self.config.attention if self.config is not None else None # "flash_attention_2",
use_xformers = False
if attention == "xformers":
use_xformers = True
attention = None
self.model = LlamaModel(LlamaConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
@ -555,9 +518,6 @@ class Base(nn.Module):
is_decoder=True,
attn_implementation=attention,
))
if use_xformers:
self.model = replace_attention( self.model, "xformers" if use_xformers else attention )
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
@ -575,9 +535,8 @@ class Base(nn.Module):
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
attn_implementation=attention,
))
elif self.arch_type == "retnet":
kwargs = dict(
vocab_size=n_resp_tokens,
@ -589,9 +548,9 @@ class Base(nn.Module):
dropout=p_dropout if training else 0.0,
checkpoint_activations=self.activation_checkpointing,
activation_fn="gelu",
use_layernorm=True, # self.version < 3,
use_biases=True, # self.version < 3,
use_glu=False, # self.version >= 3,
use_layernorm=self.version < 3,
use_biases=self.version < 3,
use_glu=self.version >= 3,
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
@ -642,6 +601,9 @@ class Base(nn.Module):
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if use_xformers:
self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention )
self.classifier = nn.Linear(d_model, n_resp_tokens)
self.accuracy_metric = MulticlassAccuracy(

View File

@ -9,6 +9,11 @@ from ..config import cfg
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
SGD = torch.optim.SGD
Adagrad = torch.optim.Adagrad
# https://github.com/kyegomez/BitNet
if cfg.optimizations.bitnet:
from bitnet import BitLinear
@ -37,19 +42,20 @@ if cfg.optimizations.bitsandbytes:
)).to(self.weight.dtype) )
"""
if cfg.optimizations.optimizers:
Adam = bnb.optim.Adam8bit
AdamW = bnb.optim.AdamW8bit
SGD = bnb.optim.SGD8bit
Adagrad = bnb.optim.Adagrad8bit
if cfg.optimizations.bitsandbytes:
import bitsandbytes as bnb
elif cfg.optimizations.dadaptation:
import dadaptation
Adam = bnb.optim.Adam8bit
AdamW = bnb.optim.AdamW8bit
SGD = bnb.optim.SGD8bit
Adagrad = bnb.optim.Adagrad8bit
else:
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
SGD = torch.optim.SGD
Adagrad = torch.optim.Adagrad
if cfg.optimizations.optimizers:
Adam = dadaptation.DAdaptAdam
AdamW = dadaptation.DAdaptAdam
SGD = dadaptation.DAdaptSGD
AdaGrad = dadaptation.DAdaptAdaGrad
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
@ -92,42 +98,112 @@ else:
def autocast():
yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
torch.nn.Linear = Linear
torch.nn.Embedding = Embedding
if cfg.optimizations.injects:
if cfg.optimizations.linear:
torch.nn.Linear = Linear
if cfg.optimizations.embedding:
torch.nn.Embedding = Embedding
torch.optim.Adam = Adam
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
if cfg.optimizations.optimizers:
torch.optim.Adam = Adam
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
def replace_linear( model, verbose=False ):
# generalizing this would be super sugoi but the there's no catch all for arguments
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet
device = next(model.parameters()).device
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
klass = Linear
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in linears:
for *parent, k in modules:
name = '.'.join(parent)
# copy parameters
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
in_features = m.in_features
out_features = m.out_features
bias = m.bias is not None
kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias)
kwargs = dict(
in_features = m.in_features,
out_features = m.out_features,
bias = m.bias is not None,
) if not bnb else dict(
input_features=m.in_features,
output_features=m.out_features,
bias=m.bias is not None,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
num_embeddings=m.num_embeddings,
embedding_dim=m.embedding_dim,
padding_idx=m.padding_idx,
max_norm=m.max_norm,
norm_type=m.norm_type,
scale_grad_by_freq=m.scale_grad_by_freq,
sparse=m.sparse,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
return model
# cannot feasibly do default arguments here sad
def replace_attention( model, klass, target, verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
m = getattr( model.get_submodule(name), k )
if isinstance(m, klass):
continue
kwargs = dict(
config = m.config,
layer_idx = m.layer_idx,
)
# overwrite
setattr(
model.get_submodule(name), k,
klass( **kwargs ).to(device=device, dtype=dtype)
)
if verbose:
@ -139,4 +215,12 @@ def replace_linear( model, verbose=False ):
try:
from prodigyopt import Prodigy
except Exception as e:
print('Error while importing Prodigyopt:', str(e))
pass
# https://github.com/facebookresearch/schedule_free/
try:
import schedulefree
except Exception as e:
print('Error while importing Schedule_Free:', str(e))
pass