From 94861677d3b0c41591b08d523a319380c8a88c9f Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 28 Feb 2025 22:07:57 -0600
Subject: [PATCH] the learning rate scheduler is a tough pill to swallow

---
 vall_e/config.py            |  5 -----
 vall_e/engines/__init__.py  | 28 +++++++++++++++++++++++++---
 vall_e/engines/base.py      | 31 ++++++++++++++++---------------
 vall_e/engines/deepspeed.py |  9 +++------
 vall_e/models/ar_nar_v2.py  | 29 +++++++++++++++++++++++++++--
 vall_e/utils/ml.py          | 41 ++++++++++++++++++++++++++++++++++++++++-
 6 files changed, 111 insertions(+), 32 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 97e31a9..6248cf8 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -504,7 +504,6 @@ class Hyperparameters:
 	warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed
 
 	scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
-	scheduler_type: str = "" # deprecated
 	scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
 	autotune: bool = False # to do deepspeed's autotuning
@@ -1063,10 +1062,6 @@ class Config(BaseConfig):
 			if model.training:
 				model.teacher = False
 
-		if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
-			self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
-			self.hyperparameters.scheduler_type = ""
-
 		# do not combine the two
 		if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
 			self.hyperparameters.scheduler = ""
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 144e474..150b717 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -175,10 +175,32 @@ def load_engines(training=True, **model_kwargs):
 				lr = params['lr'],
 				warmup_steps = cfg.hyperparameters.warmup_steps
 			)
+		elif cfg.hyperparameters.scheduler:
+			scheduler_kwargs = {}
+			if cfg.hyperparameters.scheduler.lower() == "onecycle":
+				scheduler_class = ml.OneCycleLR
+				scheduler_kwargs["max_lr"] = params['lr']
+			elif cfg.hyperparameters.scheduler.lower() == "cosineannealing":
+				scheduler_class = ml.CosineAnnealingLR
+			elif cfg.hyperparameters.scheduler.lower() == "noam":
+				scheduler_class = ml.NoamLR
+				scheduler_kwargs["d_model"] = model.d_model
+				scheduler_kwargs["warmup_steps"] = cfg.hyperparameters.warmup_steps
+			elif cfg.hyperparameters.scheduler.lower() == "warmup":
+				scheduler_class = ml.WarmupLR
+				scheduler_kwargs["warmup_steps"] = cfg.hyperparameters.warmup_steps
+			else:
+				raise ValueError(f'Specified scheduler is not implemented: {cfg.hyperparameters.scheduler}')
+
+			scheduler_kwargs.update(cfg.hyperparameters.scheduler_params)
+			lr_scheduler = scheduler_class(
+				optimizer,
+				**scheduler_kwargs,
+			)
+			"""
+			# set up our LR scheduler here
+			"""
 
-		"""
-		# set up our LR scheduler here
-		"""
 
 		if inferencing:
 			optimizer = None
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 0f3f8c8..7b05385 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -64,21 +64,20 @@ class Engine():
 			kwargs.pop("hyper_config")
 
 		self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
-		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
-		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
-
-		stats = kwargs.pop("stats", {})
-		if stats is None:
-			stats = {}
-		self.global_steps = stats.pop("global_step", 0)
-		self.micro_steps = stats.pop("micro_step", 0)
-		self.global_samples = stats.pop("global_samples", 0)
-		self.tokens_processed = stats.pop("tokens_processed", 0)
-
-		self._frozen_params = set()
-
+		self.optimizer = kwargs.get('optimizer', None)
+		self.lr_scheduler = kwargs.get('lr_scheduler', None)
 		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
 
+		stats = kwargs.get("stats", {})
+		if stats is None:
+			stats = {}
+
+		self.global_steps = stats.get("global_step", 0)
+		self.micro_steps = stats.get("micro_step", 0)
+		self.global_samples = stats.get("global_samples", 0)
+		self.tokens_processed = stats.get("tokens_processed", 0)
+
+		self._frozen_params = set()
 		self.current_batch_size = 0
 		self._global_grad_norm = None
 
@@ -256,9 +255,11 @@ class Engine():
 			self.loss_scaler.update()
 		else:
 			self.optimizer.step()
+
+		if self.lr_scheduler is not None:
+			self.lr_scheduler.step()
+
 		self.optimizer.zero_grad()
-
-		# self._get_grad_norm() # doesn't actually work
 
 
 	def _get_grad_norm(self):
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index de79fd6..f07705b 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -34,10 +34,7 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
 
 class Engine(DeepSpeedEngine):
 	def __init__(self, *args, **kwargs):
-		self.hyper_config = None
-		if 'hyper_config' in kwargs:
-			self.hyper_config = kwargs['hyper_config']
-			kwargs.pop("hyper_config")
+		self.hyper_config = kwargs.pop('hyper_config', None)
 
 		kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
 		kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
@@ -50,18 +47,18 @@ class Engine(DeepSpeedEngine):
 		}
 
 		# kwargs['stats'] = None will return None when popped
-		maybe_stats = kwargs.pop('stats', stats)
+		maybe_stats = kwargs.get('stats', stats)
 		if maybe_stats is not None:
 			stats = maybe_stats
 
 		super().__init__(None, *args, **kwargs)
-		self._frozen_params = set()
 
 		self.global_steps = stats["global_step"]
 		self.micro_steps = stats["micro_step"]
 		self.global_samples = stats["global_samples"]
 		self.tokens_processed = stats["tokens_processed"]
 
+		self._frozen_params = set()
 		self.current_batch_size = 0
 
 	def freeze(self, freeze_all=True):
diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index d8cd271..ab6b44d 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -952,6 +952,31 @@ def example_usage():
 	if scheduler is not None:
 		_logger.info(f"Scheduler: {scheduler}")
 		optimizer = scheduler( model.parameters(), lr = learning_rate )
+	elif cfg.hyperparameters.scheduler:
+		scheduler_kwargs = {}
+		if scheduler == "onecycle":
+			scheduler_class = ml.OneCycleLR
+			scheduler_kwargs["max_lr"] = learning_rate
+		elif scheduler == "cosineannealing":
+			scheduler_class = ml.CosineAnnealingLR
+		elif scheduler == "noam":
+			scheduler_class = ml.NoamLR
+			scheduler_kwargs["d_model"] = model.d_model
+			scheduler_kwargs["warmup_steps"] = cfg.hyperparameters.warmup_steps
+		elif scheduler == "warmup":
+			scheduler_class = ml.WarmupLR
+			scheduler_kwargs["warmup_steps"] = cfg.hyperparameters.warmup_steps
+		else:
+			raise ValueError(f'Specified scheduler is not implemented: {cfg.hyperparameters.scheduler}')
+
+		scheduler_kwargs.update(cfg.hyperparameters.scheduler_params)
+		scheduler = scheduler_class(
+			optimizer,
+			**scheduler_kwargs,
+		)
+
+	if isinstance(scheduler, str):
+		scheduler = None
 
 	if cfg.optimizations.replace and cfg.optimizations.linear:
 		model = ml.replace_linear( model )
@@ -968,7 +993,7 @@ def example_usage():
 	}
 	"""
 
-	engine = Engine(model=model, optimizer=optimizer)
+	engine = Engine(model=model, optimizer=optimizer, lr_scheduler=scheduler)
 	engines = Engines({"ar+nar": engine})
 	engines.setup()
 
@@ -1047,7 +1072,7 @@ def example_usage():
 	for i in t:
 		texts, proms, resps, tasks = sample_data()
 
-		stats = {"step": i}
+		stats = {"step": i, "lr": engine.get_lr()[0]}
 		with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
 			stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
diff --git a/vall_e/utils/ml.py b/vall_e/utils/ml.py
index 1b05feb..e375edd 100755
--- a/vall_e/utils/ml.py
+++ b/vall_e/utils/ml.py
@@ -1,9 +1,11 @@
 from contextlib import contextmanager
 
 import math
+import logging
+
 import torch
 import torch.nn.functional as F
-import logging
+from torch.optim.lr_scheduler import _LRScheduler
 
 from ..config import cfg
 
@@ -17,6 +19,43 @@
 AdamW = torch.optim.AdamW
 SGD = torch.optim.SGD
 Adagrad = torch.optim.Adagrad
+OneCycleLR = torch.optim.lr_scheduler.OneCycleLR
+CosineAnnealingLR = torch.optim.lr_scheduler.CosineAnnealingLR
+LambdaLR = torch.optim.lr_scheduler.LambdaLR
+
+# implements Noam scheduling
+# it's cringe
+class NoamLR(_LRScheduler):
+	def __init__(self, optimizer, warmup_steps, d_model=1024, last_epoch=-1):
+		self.base_factor = d_model ** (-0.5)
+		self.warmup_steps = warmup_steps
+
+		super().__init__(optimizer, last_epoch)
+
+	def get_lr(self):
+		step = max(1, self.last_epoch)
+		scale = self.base_factor * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))
+
+		return [base_lr * scale for base_lr in self.base_lrs]
+
+# gradually warms up LR then holds or decays
+class WarmupLR(_LRScheduler):
+	def __init__(self, optimizer, warmup_steps, decay_factor=0.0, last_epoch=-1):
+		self.warmup_steps = warmup_steps
+		self.decay_factor = decay_factor
+
+		super().__init__(optimizer, last_epoch)
+
+	def get_lr(self):
+		step = self.last_epoch + 1
+		scale = 1
+		if step < self.warmup_steps:
+			scale = float(step) / float(max(1, self.warmup_steps))
+		elif self.decay_factor != 0:
+			scale = (1.0 - self.decay_factor) ** (step - self.warmup_steps)
+
+		return [base_lr * scale for base_lr in self.base_lrs]
+
 # https://github.com/kyegomez/BitNet
 if cfg.optimizations.bitnet:
 	from bitnet import BitLinear
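
Note on usage, not part of the patch: the new branch in load_engines() is driven entirely by cfg.hyperparameters.scheduler and cfg.hyperparameters.scheduler_params, so anything the chosen torch scheduler needs beyond what the branch fills in has to arrive through scheduler_params. A minimal sketch of that resolution with plain torch objects, trimmed to one branch and using illustrative values (for example, "cosineannealing" needs its T_max supplied this way):

import torch

# illustrative stand-ins for cfg.hyperparameters.scheduler / .scheduler_params
scheduler_name = "cosineannealing"
scheduler_params = {"T_max": 10_000}

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1.0e-4)

# mirrors the selection logic added to load_engines(), reduced to one branch
scheduler_kwargs = {}
if scheduler_name.lower() == "cosineannealing":
    scheduler_class = torch.optim.lr_scheduler.CosineAnnealingLR
else:
    raise ValueError(f"Specified scheduler is not implemented: {scheduler_name}")

# user-supplied kwargs are merged last, exactly as in the patch
scheduler_kwargs.update(scheduler_params)
lr_scheduler = scheduler_class(optimizer, **scheduler_kwargs)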
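
A second sketch of the per-update behaviour: Engine.step() in base.py now calls lr_scheduler.step() right after optimizer.step(), so the schedules added to vall_e/utils/ml.py advance once per optimizer update. This assumes the patched vall_e.utils.ml is importable; the model, warmup_steps, and decay_factor below are placeholders.

import torch

from vall_e.utils.ml import WarmupLR  # NoamLR is wired up the same way

model = torch.nn.Linear(16, 16)        # placeholder for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)

# warmup_steps would normally come from cfg.hyperparameters.warmup_steps
lr_scheduler = WarmupLR(optimizer, warmup_steps=250, decay_factor=0.0)

for step in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).mean()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()                # same ordering as the new Engine.step()

    if step % 250 == 0:
        # LR ramps linearly over the first 250 steps, then holds (decay_factor=0)
        print(step, lr_scheduler.get_last_lr()[0])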