From 712808494faec676efca4800696bca1dc6f27d28 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 6 Sep 2023 20:33:16 -0500
Subject: [PATCH] added support for optional prodigy optimizer (https://github.com/konstmish/prodigy) although it consumes a lot more VRAM per parameter

---
 vall_e/models/ar_nar.py | 19 ++++++++++++-------
 vall_e/utils/trainer.py |  9 +++++++++
 vall_e/utils/wrapper.py | 14 ++++++++++----
 3 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index ccba99c..98afea4 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -1,5 +1,6 @@
-from ..config import cfg
 from .base import Base, list_to_tensor, Categorical
+from ..utils import wrapper as ml
+from ..config import cfg
 
 import torch
 from torch.nn.utils.rnn import pad_sequence
@@ -169,7 +170,7 @@ def example_usage():
 
 	from einops import repeat
 
-	from ..emb.qnt import decode_to_file
+	from ..emb.qnt import decode_to_file, unload_model
 	from ..engines import Engine
 	from tqdm import tqdm
 
@@ -201,9 +202,9 @@ def example_usage():
 
 	kwargs = {
 		'n_tokens': 1024,
-		'd_model': 1024,
-		'n_heads': 16,
-		'n_layers': 24,
+		'd_model': 1536, # 1536
+		'n_heads': 24, # 24
+		'n_layers': 24, # 32
 	}
 	"""
 
@@ -214,7 +215,9 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=0.001))
+	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+	#optimizer = ml.AdamW(model.parameters(), lr=0.0001)
+	engine = Engine(model=model, optimizer=optimizer)
 
 	def sample( name, steps=600 ):
 		engine.eval()
@@ -229,9 +232,11 @@ def example_usage():
 		for i, o in enumerate(resps_list):
 			_ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)
 
+		unload_model()
+
 	def train():
 		engine.train()
-		t = trange(1000)
+		t = trange(500)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 8c9a945..e729cc1 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -84,6 +84,15 @@ def load_engines(invert=False):
 				model.parameters(),
 				**params,
 			)
+		elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "prodigy") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "prodigy-torch"):
+			params = {
+				"lr": cfg.hyperparameters.learning_rate,
+			}
+			params.update(cfg.hyperparameters.optimizer_params)
+			optimizer = ml.Prodigy(
+				model.parameters(),
+				**params,
+			)
 
 		if not model._cfg.training:
 			optimizer = None
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index bbdbf8a..dc16236 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -29,9 +29,9 @@ if cfg.bitsandbytes.enabled:
 
 if cfg.bitsandbytes.enabled:
 	import bitsandbytes as bnb
-	Adam = bnb.optim.Adam
-	AdamW = bnb.optim.AdamW
-	SGD = bnb.optim.SGD
+	Adam = bnb.optim.Adam8bit
+	AdamW = bnb.optim.AdamW8bit
+	SGD = bnb.optim.SGD8bit
 else:
 	Adam = torch.optim.Adam
 	AdamW = torch.optim.AdamW
@@ -76,4 +76,10 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 
 	torch.optim.Adam = Adam
 	torch.optim.AdamW = AdamW
-	torch.optim.SGD = SGD
\ No newline at end of file
+	torch.optim.SGD = SGD
+
+# https://github.com/konstmish/prodigy
+try:
+	from prodigyopt import Prodigy
+except Exception as e:
+	pass
\ No newline at end of file
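
Note for reviewers: below is a minimal, self-contained sketch of the optional-optimizer pattern this patch wires in, i.e. try to import Prodigy from the prodigyopt package and fall back to AdamW when it is unavailable. The toy model and training loop are illustrative stand-ins, not the actual VALL-E code; the lr=1.0 for Prodigy and the 1.0e-4 AdamW fallback mirror the values used in example_usage() above.

    import torch
    import torch.nn as nn

    # Optional dependency, mirroring the guarded import added to vall_e/utils/wrapper.py.
    try:
        from prodigyopt import Prodigy  # https://github.com/konstmish/prodigy
    except ImportError:
        Prodigy = None

    # Toy stand-in model for demonstration only.
    model = nn.Linear(16, 4)

    if Prodigy is not None:
        # Prodigy estimates its own step size, so lr acts as a multiplier and is
        # conventionally left at 1.0 (as in the example_usage() change above).
        optimizer = Prodigy(model.parameters(), lr=1.0)
    else:
        # Fallback comparable to the commented-out AdamW line in the patch.
        optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)

    # Standard PyTorch loop; Prodigy is a drop-in torch.optim-style optimizer.
    for _ in range(10):
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

As the commit message notes, Prodigy keeps additional per-parameter state, so expect noticeably higher VRAM usage than AdamW for the same model.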