added Muon optimizer through kludge hacks, since it requires a second optimizer running in tandem, and that only seems to sometimes work with DeepSpeed
parent 67a6009555
commit 6634d07576
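For orientation before the diff: the change routes every trainable parameter with ndim >= 2 to Muon and keeps everything else on the regular optimizer, then drives both in tandem. A minimal standalone sketch of that split (generic torch model; the Muon keyword arguments are assumptions about the third-party package, not anything this repo documents):

```python
import torch
from muon import Muon  # the same third-party package the diff below imports via `from muon import Muon`

def build_tandem_optimizers(model: torch.nn.Module, base_lr: float = 1.0e-4):
	# Muon only handles weight matrices, so send every parameter with
	# ndim >= 2 its way and keep 1-D params (biases, norms) on AdamW.
	matrix_params = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
	other_params  = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]
	return [
		Muon(matrix_params, lr=2.0e-2, momentum=0.95),  # keyword names assumed; check the muon package
		torch.optim.AdamW(other_params, lr=base_lr),
	]
```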
@@ -11,7 +11,7 @@ elif cfg.trainer.backend == "local":
 	from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models, get_model
-from ..utils import wrapper as ml
+from ..utils import ml
 from ..utils.io import torch_save, torch_load, pick_path
 from ..models.lora import apply_lora, lora_load_state_dict
 
@@ -114,7 +114,6 @@ def load_engines(training=True, **model_kwargs):
 			"lr": cfg.hyperparameters.learning_rate,
 		}
 
-
 		if cfg.hyperparameters.optimizer.lower() == "adamw":
 			params["betas"] = (0.9, 0.96)
 			params["eps"] = 1e-07
 
@@ -149,8 +148,21 @@ def load_engines(training=True, **model_kwargs):
 		else:
 			raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
 
+		muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
 		params.update(cfg.hyperparameters.optimizer_params)
-		optimizer = optimizer_class(**params)
+
+		if muon_params is not None:
+			muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
+
+			params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 and f'model.{name}' not in model.config.frozen_params ]
+			params["params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') and name not in model.config.frozen_params ]
+
+			optimizer = ml.Optimizers([
+				ml.Muon(**muon_params),
+				optimizer_class(**params),
+			])
+		else:
+			optimizer = optimizer_class(**params)
 
 		if cfg.hyperparameters.scheduler.lower() == "schedulefree":
 			if cfg.hyperparameters.optimizer.lower() == "adamw":
 
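The tandem path above is opt-in: it only engages when an entry named `muon` exists under `optimizer_params`, which gets popped off and forwarded verbatim to `ml.Muon(**muon_params)`. A hypothetical illustration of setting that up in Python; the keys inside the `muon` block are placeholders, not a documented schema:

```python
# Hypothetical values: anything under "muon" is handed straight to ml.Muon(**muon_params),
# while the remaining keys still update the base optimizer's params as before.
cfg.hyperparameters.optimizer_params = {
	"muon": {
		"lr": 2.0e-2,      # assumed Muon keyword
		"momentum": 0.95,  # assumed Muon keyword
	},
	# any other keys behave exactly as they did before this commit
}
```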
@@ -406,7 +418,7 @@ def load_engines(training=True, **model_kwargs):
 			try:
 				engine.wandb = wandb.init(project=key_name, **kwargs)
+				engine.wandb.watch(engine.module)
 			except Exception as e:
 				engine.wandb = None
 		else:
 			engine.wandb = None
 
@@ -45,7 +45,7 @@ from typing import Any, Protocol
 from functools import cached_property
 
 from .base import TrainFeeder
-from ..utils import wrapper as ml
+from ..utils import ml
 
 _logger = logging.getLogger(__name__)
 
@@ -25,7 +25,7 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr
 from deepspeed.accelerator import get_accelerator
 
 from ..utils.distributed import init_distributed, distributed_initialized
-from ..utils import wrapper as ml
+from ..utils import ml
 
 from ..models.lora import freeze_non_lora_weights
 
@@ -17,7 +17,7 @@ from .emb import g2p, qnt
 from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
 from .emb.transcribe import transcribe
 
-from .utils import to_device, set_seed, clamp, wrapper as ml
+from .utils import to_device, set_seed, clamp, ml
 
 from .config import cfg, Config
 from .models import get_models
 
@@ -1189,8 +1189,13 @@ class AR_NAR(Base):
 				break
 
 		for i, l in enumerate( sequence_list ):
-			index = (l == audio_stop_token).nonzero()[:, 0].min()
-			sequence_list[i] = sequence_list[i][:index]
+			index = (l == audio_stop_token).nonzero()
+			# kludge for when it doesnt actually hit a stop token but i cant be bothered to properly address it right now since it only came up in test training at the moment
+			try:
+				index = index[:, 0].min()
+				sequence_list[i] = sequence_list[i][:index]
+			except Exception as e:
+				pass
 
 		return sequence_list
 
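The try/except above exists because `.nonzero()` returns an empty tensor when a sequence never emits the stop token, and calling `.min()` on an empty tensor raises. The same behaviour without the bare except, purely as a sketch of the logic rather than what the commit ships:

```python
stop_hits = (l == audio_stop_token).nonzero()
if stop_hits.numel() > 0:
	# truncate at the first stop token, as before
	sequence_list[i] = sequence_list[i][: stop_hits[:, 0].min()]
# otherwise leave the sequence untruncated, matching the except/pass path
```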
@@ -1362,8 +1367,6 @@ class AR_NAR(Base):
 def example_usage():
-	cfg.device = "cuda"
-	cfg.trainer.backend = "local"
 	if cfg.audio_backend == "dac":
 		cfg.sample_rate = 44_100
 
 	from functools import partial
 	from einops import repeat
 
@@ -1372,7 +1375,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
 	from ..data import _load_artifact
 	from ..engines import Engine, Engines
-	from ..utils import wrapper as ml
+	from ..utils import ml
 	from ..utils import setup_logging
 
 	import numpy as np
 
@@ -1403,7 +1406,7 @@ def example_usage():
 		'n_text_tokens': cfg.model.text_tokens,
 		'n_audio_tokens': cfg.model.audio_tokens,
 
-		'd_model': 1536, # 256, # 1024, # 1536
+		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
 		'n_experts': 1 if not cfg.model else cfg.model.experts,
 
@@ -1425,7 +1428,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]
 
 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 250 // batch_size
+	steps = 100 // batch_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
 
@@ -1464,28 +1467,23 @@ def example_usage():
 			learning_rate = 0.01
 
 		optimizer = ml.Apollo
-
-		"""
-		target_params = []
-		target_modules_list = ["attn", "mlp"]
-		for module_name, module in model.named_modules():
-			if not (isinstance(module, torch.nn.Linear)):
-				continue
-			if not any(target_key in module_name for target_key in target_modules_list):
-				continue
-			target_params.append(module.weight)
-
-		param_ids = [id(p) for p in target_params]
-		regular_params = [p for p in model.parameters() if id(p) not in param_ids]
-		params = [{'params': regular_params}, {'params': target_params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
-		"""
 		params = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
 	else:
 		raise ValueError(f"Unrecognized optimizer: {optimizer}")
 
 	_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
 
-	optimizer = optimizer(params, lr=learning_rate)
+	muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
+	if muon_params is not None:
+		muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
+		adam_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ] + [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
+
+		optimizer = ml.Optimizers([
+			ml.Muon(**muon_params),
+			optimizer(adam_params, lr=learning_rate)
+		])
+	else:
+		optimizer = optimizer(params, lr=learning_rate)
 
 	if scheduler == "schedulefree":
 		if isinstance(optimizer, ml.AdamW):
 
@@ -14,7 +14,7 @@ from einops import rearrange
 from torch import Tensor, einsum, nn
 from torch.utils.checkpoint import checkpoint
 
-from ...utils import wrapper as ml
+from ...utils import ml
 
 class AdaLN(nn.Module):
 	def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
 
@@ -30,7 +30,7 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 from .arch import *
-from ..utils import wrapper as ml, clamp
+from ..utils import ml, clamp
 from ..samplers import *
 
 # yuck, kind of needed
 
@@ -85,22 +85,6 @@ if cfg.optimizations.injects:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD
 
-AVAILABLE_COMPILE_BACKENDS = []
-
-try:
-	AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
-except Exception as e:
-	pass
-
-
-if cfg.optimizations.tensorrt:
-	try:
-		import torch_tensorrt
-		AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
-	except Exception as e:
-		_logger.warning(f'Error while importing TensorRT: {str(e)}')
-		pass
-
 if cfg.optimizations.unsloth:
 	try:
 		from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
 
@@ -109,20 +93,48 @@ if cfg.optimizations.unsloth:
 		_logger.warning(f'Error while importing Unsloth: {str(e)}')
 		pass
 
+class Optimizers(torch.optim.Optimizer):
+	def __init__(self, opts):
+		self.opts = opts
+
+	def step(self, *args, **kwargs):
+		for opt in self.opts:
+			opt.step(*args, **kwargs)
+
+	def zero_grad(self, *args, **kwargs):
+		for opt in self.opts:
+			opt.zero_grad(*args, **kwargs)
+
+	@property
+	def param_groups(self):
+		l = []
+		for opt in self.opts:
+			l += opt.param_groups
+		return l
+
+	def state_dict(self):
+		states = []
+		for i, opt in enumerate( self.opts ):
+			states.append( opt.state_dict() )
+
+		return states
+
+	def load_state_dict(self, state_dict):
+		for opt, state in zip( self.opts, state_dict ):
+			opt.load_state_dict( state )
+
 try:
 	from .ext.apollo import Apollo
 except Exception as e:
 	_logger.warning(f'Error while importing APOLLO: {str(e)}')
 	pass
 
-def compile_model(model, backend="auto"):
-	if not backend or backend == "auto":
-		backend = AVAILABLE_COMPILE_BACKENDS[0]
-
-	if backend not in AVAILABLE_COMPILE_BACKENDS:
-		return torch.compile(model)
-
-	return torch.compile(model, backend=backend)
+try:
+	from muon import Muon as Muon
+except Exception as e:
+	raise e
+	#_logger.warning(f'Error while importing Muon: {str(e)}')
+	#pass
 
 # https://github.com/konstmish/prodigy
 try:
 
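A small usage sketch for the `Optimizers` wrapper added above, with the class in scope and two stock torch optimizers standing in for `ml.Muon` plus the base optimizer (toy model and hyperparameters are placeholders):

```python
import torch

net = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
matrix_params = [p for p in net.parameters() if p.ndim >= 2]
other_params  = [p for p in net.parameters() if p.ndim < 2]

opt = Optimizers([
	torch.optim.SGD(matrix_params, lr=1e-2),   # stand-in for ml.Muon(**muon_params)
	torch.optim.AdamW(other_params, lr=1e-4),
])

loss = net(torch.randn(8, 16)).sum()
opt.zero_grad()
loss.backward()
opt.step()                        # steps every wrapped optimizer in order
state = opt.state_dict()          # a list with one state_dict per wrapped optimizer
opt.load_state_dict(state)
```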
@@ -154,4 +166,29 @@ def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False )
 def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
 	return replace_embedding_old( model, klass, target, verbose )
 
 Embedding.forward = autocast_forward(Embedding.forward)
+
+AVAILABLE_COMPILE_BACKENDS = []
+
+try:
+	AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
+except Exception as e:
+	pass
+
+def compile_model(model, backend="auto"):
+	if not backend or backend == "auto":
+		backend = AVAILABLE_COMPILE_BACKENDS[0]
+
+	if backend not in AVAILABLE_COMPILE_BACKENDS:
+		return torch.compile(model)
+
+	return torch.compile(model, backend=backend)
+
+
+if cfg.optimizations.tensorrt:
+	try:
+		import torch_tensorrt
+		AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
+	except Exception as e:
+		_logger.warning(f'Error while importing TensorRT: {str(e)}')
+		pass
 
@@ -33,7 +33,7 @@ from .distributed import (
 from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
 
 from .utils import to_device, do_gc, truncate_json
-from ..utils import wrapper as ml
+from ..utils import ml
 from ..data import get_phone_symmap # should decouple from this trainer script
 
 _logger = logging.getLogger(__name__)
 
@@ -53,7 +53,7 @@ def md5_hash( x ):
 	return hashlib.md5(str(x).encode("utf-8")).hexdigest()
 
 # removes entries from a dict if that key is missing from the source
-def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
+def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True, ignore=["optimizer_params"] ):
 	is_obj = hasattr( source, "__dict__" )
 	if parent_is_obj is None:
 		parent_is_obj = is_obj
 
@@ -65,6 +65,9 @@ def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
 			keep[k] = dest[k]
 		else:
 			missing.append(".".join(path + [k]))
 
+		if k in ignore:
+			continue
+
 		if recurse and isinstance( v, dict ):
 			keep[k], m = prune_missing( haystack[k], dest[k], path=path + [k], parent_is_obj=parent_is_obj, return_missing=return_missing )