added support for the optional Prodigy optimizer (https://github.com/konstmish/prodigy), although it consumes a lot more VRAM per parameter

mrq 2023-09-06 20:33:16 -05:00
parent 7ce06432fd
commit 712808494f
3 changed files with 31 additions and 11 deletions
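
For context: Prodigy estimates its own step size, so it is normally constructed with lr=1.0 rather than a tuned learning rate (exactly what the diff below does), and the extra VRAM cost mentioned above comes from the additional per-parameter optimizer state it keeps. A minimal standalone sketch, assuming prodigyopt is installed and using a toy linear model purely as a stand-in for the real one:

import torch
from prodigyopt import Prodigy

# toy stand-in model; any nn.Module is handled the same way
model = torch.nn.Linear(16, 16)

# Prodigy adapts the effective step size itself, so lr is left at 1.0
optimizer = Prodigy(model.parameters(), lr=1.0)

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()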


@@ -1,5 +1,6 @@
from ..config import cfg
from .base import Base, list_to_tensor, Categorical
from ..utils import wrapper as ml
from ..config import cfg
import torch
from torch.nn.utils.rnn import pad_sequence
@@ -169,7 +170,7 @@ def example_usage():
from einops import repeat
from ..emb.qnt import decode_to_file
from ..emb.qnt import decode_to_file, unload_model
from ..engines import Engine
from tqdm import tqdm
@@ -201,9 +202,9 @@ def example_usage():
kwargs = {
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 24,
'd_model': 1536, # 1536
'n_heads': 24, # 24
'n_layers': 24, # 32
}
"""
@@ -214,7 +215,9 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=0.001))
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
#optimizer = ml.AdamW(model.parameters(), lr=0.0001)
engine = Engine(model=model, optimizer=optimizer)
def sample( name, steps=600 ):
engine.eval()
@@ -229,9 +232,11 @@ def example_usage():
for i, o in enumerate(resps_list):
_ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)
unload_model()
def train():
engine.train()
t = trange(1000)
t = trange(500)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)


@@ -84,6 +84,15 @@ def load_engines(invert=False):
model.parameters(),
**params,
)
elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "prodigy") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "prodigy-torch"):
params = {
"lr": cfg.hyperparameters.learning_rate,
}
params.update(cfg.hyperparameters.optimizer_params)
optimizer = ml.Prodigy(
model.parameters(),
**params,
)
if not model._cfg.training:
optimizer = None

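In other words, with the local trainer backend the new branch is selected by setting cfg.hyperparameters.optimizer to "prodigy" (or "prodigy-torch" when running under deepspeed), and the configured learning_rate plus any optimizer_params are passed straight through to ml.Prodigy.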

@@ -29,9 +29,9 @@ if cfg.bitsandbytes.enabled:
if cfg.bitsandbytes.enabled:
import bitsandbytes as bnb
Adam = bnb.optim.Adam
AdamW = bnb.optim.AdamW
SGD = bnb.optim.SGD
Adam = bnb.optim.Adam8bit
AdamW = bnb.optim.AdamW8bit
SGD = bnb.optim.SGD8bit
else:
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
@@ -76,4 +76,10 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
torch.optim.Adam = Adam
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
torch.optim.SGD = SGD
# https://github.com/konstmish/prodigy
try:
from prodigyopt import Prodigy
except Exception as e:
pass
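
Because the import is wrapped in a try/except, prodigyopt stays an optional dependency: nothing breaks when the package is missing, and it only needs to be installed (presumably via pip as prodigyopt) when the prodigy optimizer is actually selected in the config.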