made muon actually work by actually utilizing param groups (thanks APOLLO for reminding me this is the sane way to handle this split)
parent de27115bb7
commit 95da4e9405
@@ -19,10 +19,9 @@ A training paradigm that works for me is:
 * additional training for sampling per speaker, to better help diversify how well it can perform for a range of speakers, rather than just speaking itself
 * I don't think this is crucial, but speaker-based sampling seems to be a placebo if anything.
 
-Training under `float16` should be fairly simple, but care is required to keep the loss scaling factor above 8K, and probably even 16K.
-* At the very least for pre-trained models, low enough loss scales will irreparably fry the model, and no amount of training afterwards seems to "fix" it.
-* The current DeepSpeed configuration should keep the loss scale capped to 32K; normal training does not seem to have the loss scale ever want to dip below this at least.
-* Training under `bfloat16` does not have to worry about this as there's no need for loss scaling, but I feel the model performs better when trained under `float16`+AMP rather than `bfloat16` (with or without AMP).
+Training under `float16` (+AMP) should be fairly simple, but it's practically required to use the `deepspeed` backend.
+* This is because `deepspeed` will automatically wrap the optimizer to handle training under `float16`, while the `local` backend does not do this. Training will *not* converge.
+* Training under `bfloat16` does not have to worry about this.
 
 When training from scratch, maybe 30% of the time spent training is getting coherent speech, with a loose following of the prompt. The remaining bulk of the work is getting the model to closely-er resemble the input prompt.
 * an accuracy of at least 50% seems to be where coherent speech emerges.
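For context on the `float16` notes above, here is a rough sketch of the `fp16` section of a DeepSpeed config that matches the described behavior (dynamic loss scaling starting around 32K with a floor well above 8K). The values are illustrative assumptions, not the repo's actual generated config.

```python
# Hypothetical DeepSpeed fp16 settings mirroring the notes above; the project
# generates its own config, so treat these values as placeholders.
ds_config = {
    "fp16": {
        "enabled": True,             # train under float16 (+AMP) via the deepspeed backend
        "loss_scale": 0,             # 0 = dynamic loss scaling
        "initial_scale_power": 15,   # 2**15 = 32768, the ~32K scale mentioned above
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 8192,      # keep the scale from collapsing low enough to fry the model
    },
}
```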
@@ -146,42 +146,22 @@ def load_engines(training=True, **model_kwargs):
 elif cfg.hyperparameters.optimizer.lower() == "adagrad":
 optimizer_class = ml.Adagrad
 elif cfg.hyperparameters.optimizer.lower() == "muon":
-del params["params"]
-optimizer_class = ml.Muon
+optimizer = ml.Muon
 
+muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
+adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
+adamw_params += [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
 
-params["muon_params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
-params["adamw_params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 and f'model.{name}' not in model.config.frozen_params ]
-params["adamw_params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') and name not in model.config.frozen_params ]
-
-if cfg.hyperparameters.optimizer_params is not None:
-params["adamw_betas"] = cfg.hyperparameters.optimizer_params.pop("adamw_betas", (0.95, 0.95))
-params["adamw_eps"] = cfg.hyperparameters.optimizer_params.pop("adamw_eps", 1e-8)
+params["params"] = [
+{ "params": muon_params, "muon": True },
+{ "params": adamw_params, "muon": False, "betas": (0.95, 0.95), "eps": 1e-8 },
+]
 else:
 raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
 
 params.update(cfg.hyperparameters.optimizer_params)
 optimizer = optimizer_class(**params)
 
-"""
-if cfg.hyperparameters.optimizer_params is not None:
-muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
-params.update(cfg.hyperparameters.optimizer_params)
-
-if muon_params is not None:
-muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
-
-params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 and f'model.{name}' not in model.config.frozen_params ]
-params["params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') and name not in model.config.frozen_params ]
-
-optimizer = ml.Optimizers([
-ml.Muon(**muon_params),
-optimizer_class(**params),
-])
-else:
-optimizer = optimizer_class(**params)
-"""
-
 if cfg.hyperparameters.scheduler.lower() == "schedulefree":
 if cfg.hyperparameters.optimizer.lower() == "adamw":
 scheduler_class = ml.schedulefree.AdamWScheduleFree
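The param-group split the new `load_engines` code builds can be summarized with the following sketch (names are taken from the diff above; the learning rate is a placeholder, not the repo's actual value):

```python
# Sketch of the Muon/AdamW split via param groups, mirroring the diff above.
# `model` and `ml.Muon` are the repo's objects; the lr below is a placeholder.
muon_params  = [p for n, p in model.model.named_parameters() if p.ndim >= 2]   # matrices -> Muon update
adamw_params = [p for n, p in model.model.named_parameters() if p.ndim <  2]   # vectors/scalars -> AdamW fallback
adamw_params += [p for n, p in model.named_parameters() if not n.startswith("model.")]  # embeddings, classifiers, etc.

optimizer = ml.Muon(
    params=[
        {"params": muon_params,  "muon": True},
        {"params": adamw_params, "muon": False, "betas": (0.95, 0.95), "eps": 1e-8},
    ],
    lr=1.0e-4,  # placeholder
)
```

Everything not tagged `"muon": True` is handled by the optimizer's AdamW fallback, since `muon=False` is part of the group defaults (see the `Muon.__init__` hunks further down).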
@@ -233,81 +213,9 @@ def load_engines(training=True, **model_kwargs):
 for k in erase:
 del state[k]
 
-# converts an AR+NAR model into an AR+NAR-len model
-"""
-if True:
-# move STT one over
-state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
-state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
-# copy from AR:0:0 classifier
-if True:
-state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
-state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
-# copy from AR:0:0 embeddings
-state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
-# remove
-else:
-if 'classifiers.proj.8.weight' in state:
-del state['classifiers.proj.8.weight']
-if 'classifiers.proj.8.bias' in state:
-del state['classifiers.proj.8.bias']
-if 'resps_emb.embeddings.8.weight' in state:
-del state['resps_emb.embeddings.8.weight']
-"""
-
 # resize modules if I'm doing experiments and can't be assed to manually trim things
 if cfg.trainer.resize_modules:
-uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) > 0 else 0
-keys = [
-("text_emb.weight", model.config.text_tokens ),
-("tasks_emb.weight", model.config.tasks ),
-("langs_emb.weight", model.config.langs ),
-("rvq_l_emb.weight", model.config.resp_levels ),
-("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
-("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
-("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ),
-("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ),
-("classifier.weight", model.n_vocab ),
-("classifier.bias", model.n_vocab ),
-]
-
-last_embedding_keys = {}
-
-# correcting an oversight
-"""
-if model.config.experimental.split_classifiers and "len" in model.capabilities:
-len_idx, nar_0_idx = model.classifiers.indices(["len", "NAR:0:0"])
-keys.append((f"classifiers.proj.{len_idx}.weight", 11))
-keys.append((f"classifiers.proj.{len_idx}.bias", 11))
-
-keys.append((f"classifiers.proj.{nar_0_idx}.weight", model.config.audio_tokens))
-keys.append((f"classifiers.proj.{nar_0_idx}.bias", model.config.audio_tokens))
-"""
-
-# correcting an oversight
-"""
-if True:
-keys.append((f"classifiers.proj.0.weight", model.config.audio_tokens+1))
-for i in range(1,9):
-keys.append((f"classifiers.proj.{i}.weight", model.config.audio_tokens))
-
-keys.append((f"resps_emb.embeddings.0.weight", model.config.audio_tokens+1))
-keys.append((f"resps_emb.embeddings.8.weight", model.config.audio_tokens+1))
-
-for i in range(1,8):
-keys.append((f"resps_emb.embeddings.{i}.weight", model.config.audio_tokens))
-
-for i in range(8):
-keys.append((f"proms_emb.embeddings.{i}.weight", model.config.audio_tokens))
-
-last_embedding_keys = {
-"classifiers.proj.0.weight": state["classifiers.proj.0.weight"][-1].clone().detach(),
-"resps_emb.embeddings.0.weight": state["resps_emb.embeddings.0.weight"][-1].clone().detach(),
-"resps_emb.embeddings.8.weight": state["resps_emb.embeddings.8.weight"][-1].clone().detach(),
-}
-"""
+keys = []
 
 for k, tokens in keys:
 if k not in state:
 continue
@@ -316,50 +224,6 @@ def load_engines(training=True, **model_kwargs):
 for k, v in last_embedding_keys.items():
 state[k][-1] = v
 
-# stuff to inject new layers into an existing model train over (not recommended, it doesnt amount to anything)
-"""
-if True:
-remapped_dict = {}
-remapped_indices = [
-(0, 1),
-(1, 2),
-(2, 3),
-(3, 5),
-(4, 6),
-(5, 7),
-(6, 9),
-(7, 10),
-(8, 11),
-(9, 13),
-(10, 14),
-(11, 15),
-]
-
-for src, dst in remapped_indices:
-remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"]
-remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"]
-remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"]
-remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"]
-remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"]
-remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"]
-remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"]
-remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"]
-remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"]
-
-del state[f"model.layers.{src}.input_layernorm.weight"]
-del state[f"model.layers.{src}.self_attn.k_proj.weight"]
-del state[f"model.layers.{src}.self_attn.q_proj.weight"]
-del state[f"model.layers.{src}.self_attn.v_proj.weight"]
-del state[f"model.layers.{src}.self_attn.o_proj.weight"]
-del state[f"model.layers.{src}.post_attention_layernorm.weight"]
-del state[f"model.layers.{src}.mlp.down_proj.weight"]
-del state[f"model.layers.{src}.mlp.gate_proj.weight"]
-del state[f"model.layers.{src}.mlp.up_proj.weight"]
-
-for k, v in remapped_dict.items():
-state[k] = v
-"""
-
 model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 # load lora weights if exists
@@ -162,6 +162,7 @@ class AR_NAR(Base):
 quant_levels[i] = prom.shape[-1] - 1
 
 # apply token dropout error compensation
+"""
 if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
 steps = resps.shape[0]
 for l in range( quant_level ):
@@ -171,6 +172,7 @@ class AR_NAR(Base):
 if random.random() < token_dropout_error:
 offset = 1 * ( 1 if random.random() < 0.5 else -1 )
 resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
+"""
 
 # only apply stop token for RVQ level 0
 if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally):
@@ -1471,18 +1473,22 @@ def example_usage():
 learning_rate = 0.01
 
 optimizer = ml.Apollo
-params["params"] = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
+params["params"] = [
+{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}
+]
 elif optimizer == "muon":
-del params["params"]
 optimizer = ml.Muon
 
-params["muon_params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
-params["adamw_params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
-params["adamw_params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
+muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
+adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
+adamw_params += [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
 
-if cfg.hyperparameters.optimizer_params is not None:
-params["adamw_betas"] = cfg.hyperparameters.optimizer_params.pop("adamw_betas", (0.95, 0.95))
-params["adamw_eps"] = cfg.hyperparameters.optimizer_params.pop("adamw_eps", 1e-8)
+params["params"] = [
+{ "params": muon_params, "muon": True },
+{ "params": adamw_params, "muon": False, "betas": (0.95, 0.95), "eps": 1e-8 },
+]
+elif optimizer == "cosmos":
+optimizer = ml.COSMOS
 else:
 raise ValueError(f"Unrecognized optimizer: {optimizer}")
 
@@ -91,15 +91,8 @@ def _get_offsets():
 "resps|NAR:0:0": (16677, 17702),
 }
 
-def _dropout_mask( input, p=None ):
-# cosine scheduling
-if p is None:
-t = random.random()
-p = math.cos(t * math.pi * 0.5)
-
-seq = [ random.random() < p for _ in range( input.shape[0] ) ]
-mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
-return mask
+def _dropout_mask( input, p ):
+return (torch.rand(input.shape[0], device=input.device) < p)
 
 def _create_mask(l, device):
 """1 is valid region and 0 is invalid."""
@@ -1383,13 +1376,14 @@ class Base(nn.Module):
 )
 
 # apply token dropout
+"""
 if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
 steps = embedding.shape[0] - (1 if quant_level == 0 else 0) # do not mess with stop token
 for i in range( steps ):
 if random.random() > token_dropout_rate:
 continue
 
 embedding[i] = self.dropout_token
+"""
 elif name == "timestep" and self.time_emb is not None:
 embedding = self.time_emb( input )
 elif name == "len" and self.len_emb is not None:
@@ -66,15 +66,14 @@ class Muon(torch.optim.Optimizer):
 
 def __init__(
 self,
+params=None,
 lr=1e-3,
 wd=0.1,
-muon_params=None,
 momentum=0.95,
 nesterov=True,
 ns_steps=5,
-adamw_params=None,
-adamw_betas=(0.95, 0.95),
-adamw_eps=1e-8,
+betas=(0.95, 0.95),
+eps=1e-8,
 ):
 
 defaults = dict(
@@ -83,22 +82,12 @@ class Muon(torch.optim.Optimizer):
 momentum=momentum,
 nesterov=nesterov,
 ns_steps=ns_steps,
-adamw_betas=adamw_betas,
-adamw_eps=adamw_eps,
+betas=betas,
+eps=eps,
+muon=False,
 )
 
-params = list(muon_params)
-adamw_params = list(adamw_params) if adamw_params is not None else []
-params.extend(adamw_params)
 super().__init__(params, defaults)
-# Sort parameters into those for which we will use Muon, and those for which we will not
-for p in muon_params:
-# Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
-assert p.ndim == 2, p.ndim
-self.state[p]["use_muon"] = True
-for p in adamw_params:
-# Do not use Muon for parameters in adamw_params
-self.state[p]["use_muon"] = False
 
 def adjust_lr_for_muon(self, lr, param_shape):
 A, B = param_shape[:2]
@@ -125,19 +114,14 @@ class Muon(torch.optim.Optimizer):
 ############################
 # Muon #
 ############################
-# this actually doesn't work with deepspeed for the same reason APOLLO required modifications:
-# deepspeed's BF16/F16 optimizer wrapper modifies the tensors, so self.state loses the right mapping
-# can't be assed to figure it out right now since it's not easy to fix like APOLLO
-
-params = [p for p in group["params"] if self.state[p]["use_muon"]]
+if group["muon"]:
 # import pdb; pdb.set_trace()
 lr = group["lr"]
 wd = group["wd"]
 momentum = group["momentum"]
 
 # generate weight updates in distributed fashion
-for p in params:
+for p in group["params"]:
 # sanity check
 g = p.grad
 if g is None:
@@ -170,14 +154,13 @@ class Muon(torch.optim.Optimizer):
 ############################
 # AdamW backup #
 ############################
-params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+else:
 lr = group['lr']
-beta1, beta2 = group["adamw_betas"]
-eps = group["adamw_eps"]
+beta1, beta2 = group["betas"]
+eps = group["eps"]
 weight_decay = group["wd"]
 
-for p in params:
+for p in group["params"]:
 g = p.grad
 if g is None:
 continue
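To illustrate the per-group dispatch after the change above, a toy usage sketch (the class is the patched `Muon`; the shapes, learning rate, and parameter split are arbitrary assumptions):

```python
# Toy check of the per-group dispatch: a group tagged "muon": True takes the
# Muon branch of step(), everything else takes the AdamW fallback.
import torch

w = torch.nn.Parameter(torch.randn(8, 8))   # 2D -> Muon branch
b = torch.nn.Parameter(torch.randn(8))      # 1D -> AdamW branch

opt = Muon(
    params=[
        {"params": [w], "muon": True},
        {"params": [b], "muon": False},     # inherits betas/eps from the group defaults
    ],
    lr=1e-3,  # placeholder
)

(w.sum() + b.sum()).backward()
opt.step()       # step() branches on group["muon"] for each param group
opt.zero_grad()
```

Because `muon=False` sits in the group defaults and `betas`/`eps` are now plain group hyperparameters, the Muon/AdamW split lives entirely in the param groups rather than in `self.state`, which is presumably what lets the optimizer survive DeepSpeed's fp16/bf16 wrappers re-creating parameter tensors.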