From 0d5d545a40392d39b1d37e4d02993203d52f0fbf Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 9 May 2024 20:28:20 -0500 Subject: [PATCH] crammed in DAdaptation (doesn't seem worth it) and ScheduleFree (forgot I wanted to weeks ago, seems promising), optimization wrapper cleanup, test trainer changes, etc. --- vall_e/config.py | 38 +++++++--- vall_e/data.py | 19 ++--- vall_e/emb/qnt.py | 2 +- vall_e/engines/__init__.py | 27 ++++++- vall_e/models/ar_nar.py | 66 ++++++++++++++--- vall_e/models/base.py | 78 ++++++-------------- vall_e/utils/wrapper.py | 142 +++++++++++++++++++++++++++++-------- 7 files changed, 256 insertions(+), 116 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 38f775e..810d1d0 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -306,11 +306,14 @@ class Hyperparameters: optimizer: str = "Adamw" torch_optimizer: bool = False + optimizer_params: dict = field(default_factory=lambda: {}) learning_rate: float = 3.25e-4 - scheduler_type: str = "" + scheduler: str = "" + scheduler_type: str = "" # deprecated scheduler_params: dict = field(default_factory=lambda: {}) + torch_scheduler: bool = False @dataclass() class Evaluation: @@ -337,7 +340,7 @@ class DeepSpeed: for k in cfg.hyperparameters.scheduler_params: scheduler_params[k] = cfg.hyperparameters.scheduler_params[k] - if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params: + if cfg.hyperparameters.scheduler == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params: scheduler_params['total_num_steps'] = cfg.trainer.iterations ds_cfg = { @@ -350,9 +353,9 @@ class DeepSpeed: } } if not cfg.hyperparameters.torch_optimizer else None, "scheduler": { - "type": cfg.hyperparameters.scheduler_type, + "type": cfg.hyperparameters.scheduler, "params": scheduler_params, - } if cfg.hyperparameters.scheduler_type != "" else None, + } if not cfg.hyperparameters.torch_scheduler else None, "gradient_clipping": cfg.hyperparameters.gradient_clipping, "fp16": { "enabled": True, @@ -544,15 +547,17 @@ class Inference: # should be renamed to optimizations @dataclass() class Optimizations: - bitsandbytes: bool = False - injects: bool = False - replace: bool = False + injects: bool = False # overwrites default torch classes (not recommended) + replace: bool = False # replaces modules in place with the optimized version (recommended) - linear: bool = True - embedding: bool = True + linear: bool = True # inject/replace linear for BnB + embedding: bool = True # inject/replace embedding for BnB + optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation) - bitnet: bool = False - fp8: bool = False + bitsandbytes: bool = False # use bitsandbytes + dadaptation: bool = True # use dadaptation optimizer + bitnet: bool = False # use bitnet + fp8: bool = False # use fp8 @dataclass() class Config(_Config): @@ -636,6 +641,17 @@ class Config(_Config): else: self.optimizations = Optimizations(**self.optimizations) + if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler: + self.hyperparameters.scheduler = self.hyperparameters.scheduler_type + self.hyperparameters.scheduler_type = "" + + # do not combine the two + if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation: + self.hyperparameters.scheduler = "" + + if self.hyperparameters.scheduler == "": + self.hyperparameters.torch_scheduler = True + # Preserves the old behavior class NaiveTokenizer: def get_vocab( self ): diff --git a/vall_e/data.py b/vall_e/data.py index 
f6e02c9..fa61289 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -379,6 +379,11 @@ class Dataset(_Dataset): path = random.choice(choices) if cfg.dataset.use_hdf5: key = _get_hdf5_path(path) + + if "audio" not in cfg.hdf5[key]: + _logger.warning("MISSING AUDIO:", key) + continue + qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16) else: qnt = _load_quants(path) @@ -763,15 +768,15 @@ def create_dataset_metadata( skip_existing=True ): name = str(dir) name = name.replace(root, "") - # yucky speaker_name = name - if "LbriTTS-R" in speaker_name: - speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox") metadata_path = Path(f"{metadata_root}/{speaker_name}.json") metadata_path.parents[0].mkdir(parents=True, exist_ok=True) - metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) + try: + metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read()) + except Exception as e: + metadata = {} if not os.path.isdir(f'{root}/{name}/'): return @@ -872,8 +877,8 @@ def create_dataset_hdf5( skip_existing=True ): # yucky speaker_name = name - if "LbriTTS-R" in speaker_name: - speaker_name = speaker_name.replace("LbriTTS-R", "LibriVox") + if "LibriTTS-R" in speaker_name: + speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox") metadata_path = Path(f"{metadata_root}/{speaker_name}.json") metadata_path.parents[0].mkdir(parents=True, exist_ok=True) @@ -899,10 +904,8 @@ def create_dataset_hdf5( skip_existing=True ): key = f'{type}/{speaker_name}/{id}' - """ if skip_existing and key in hf: continue - """ group = hf.create_group(key) if key not in hf else hf[key] diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 8229a51..35cc8ab 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -143,7 +143,7 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels): @cache def _load_dac_model(device="cuda", levels=cfg.model.max_levels): - kwargs = dict(model_type="24khz",model_bitrate="8kbps",tag="latest") + kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest") """ if not cfg.variable_sample_rate: # yes there's a better way, something like f'{cfg.sample.rate//1000}hz' diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index faf927a..3f4e45a 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -45,11 +45,16 @@ def load_engines(training=True): if inferencing: model._cfg.training = False - if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8): + if cfg.optimizations.replace and cfg.optimizations.linear: model.model = ml.replace_linear( model.model ) + + if cfg.optimizations.replace and cfg.optimizations.embedding: + model.model = ml.replace_embedding( model.model ) if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer): optimizer_class = None + scheduler_class = None + params = { "lr": cfg.hyperparameters.learning_rate, } @@ -58,6 +63,10 @@ def load_engines(training=True): params["eps"] = 1e-07 params["weight_decay"] = 0.01 + # for dadaptation since it has Adam only + if ml.AdamW == ml.Adam: + params["decouple"] = True + optimizer_class = ml.AdamW elif cfg.hyperparameters.optimizer.lower() == "sgd": optimizer = ml.SGD @@ -72,11 +81,27 @@ def load_engines(training=True): raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}') params.update(cfg.hyperparameters.optimizer_params) + optimizer = 
optimizer_class( [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], **params, ) + if cfg.hyperparameters.scheduler.lower() == "schedulefree": + if cfg.hyperparameters.optimizer.lower() == "adamw": + scheduler_class = ml.schedulefree.AdamWScheduleFree + elif cfg.hyperparameters.optimizer.lower() == "sgd": + scheduler_class = ml.schedulefree.SGDScheduleFree + else: + raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}') + + optimizer = scheduler_class( + [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], + lr = params['lr'] + ) + + + # set up our LR scheduler here if inferencing: diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 3164b36..7cb24b0 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -365,7 +365,7 @@ def example_usage(): 'n_tokens': 1024, 'd_model': 1024, # 256, # 1024, # 1536 'n_heads': 16, # 4, # 16, # 24 - 'n_layers': 12, # 32 + 'n_layers': 4, # 32 'n_experts': 1, 'l_padding': 8 if cfg.optimizations.fp8 else 0, @@ -381,16 +381,66 @@ def example_usage(): """ model = AR_NAR(**kwargs).to(device) - steps = 100 - optimizer = ml.Prodigy(model.parameters(), lr=1.0) - #optimizer = ml.Adagrad(model.parameters(), lr=1.0e-2) - #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4) + steps = 1000 + optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy" + scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else "" + learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None + + if cfg.optimizations.dadaptation: + # do not combine the two + if scheduler == "schedulefree": + scheduler = "" + + learning_rate = 1.0 + + if optimizer == "prodigy": + if learning_rate is None: + learning_rate = 1.0 + + optimizer = ml.Prodigy + elif optimizer == "adagrad": + if learning_rate is None: + learning_rate = 1.0e-2 + + optimizer = ml.Adagrad + elif optimizer == "adamw": + if learning_rate is None: + learning_rate = 1.0e-4 + + optimizer = ml.AdamW + elif optimizer == "sdg": + if learning_rate is None: + learning_rate = 1.0e-4 + + optimizer = ml.SGD + else: + raise ValueError(f"Unrecognized optimizer: {optimizer}") + + print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) + + optimizer = optimizer(model.parameters(), lr=learning_rate) + + if scheduler == "schedulefree": + if isinstance(optimizer, ml.AdamW): + scheduler = ml.schedulefree.AdamWScheduleFree + elif isinstance(optimizer, ml.SGD): + scheduler = ml.schedulefree.SGDScheduleFree + else: + scheduler = None + + if scheduler is not None: + print("Scheduler:", scheduler) + optimizer = scheduler( model.parameters(), lr = learning_rate ) + + if cfg.optimizations.replace and cfg.optimizations.linear: + model = ml.replace_linear( model ) + + if cfg.optimizations.replace and cfg.optimizations.embedding: + model = ml.replace_embedding( model ) + engine = Engine(model=model, optimizer=optimizer) - if (cfg.optimizations.bitsandbytes and cfg.optimizations.replace) or (cfg.optimizations.fp8): - model.model = ml.replace_linear( model.model ) - torch.save( { 'module': model.state_dict() }, "./data/test.pth" ) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index c764557..2cd0d04 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -16,6 +16,8 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.checkpoint import checkpoint from 
torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision +from ..utils import wrapper as ml + from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample try: @@ -191,48 +193,9 @@ try: attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value - - LLAMA_ATTENTIONS["xformers"] = LLamaXformersAttention - except Exception as e: print("Error creating `LLamaXformersAttention`:", e) -def replace_attention( model, impl, verbose=False ): - device = next(model.parameters()).device - dtype = next(model.parameters()).dtype - attentions = [k.split('.') for k, m in model.named_modules() if isinstance(m, LlamaAttention)] - - if impl not in LLAMA_ATTENTIONS: - print(f"Attention '{imp} is not in LLAMA_ATTENTIONS'") - return model - - klass = LLAMA_ATTENTIONS[impl] - - for *parent, k in attentions: - name = '.'.join(parent) - - # copy parameters - m = getattr( model.get_submodule(name), k ) - - if isinstance(m, klass): - continue - - config = m.config - layer_idx = m.layer_idx - - kwargs = dict(config=config, layer_idx=layer_idx) - - # overwrite - setattr( - model.get_submodule(name), k, - klass( **kwargs ).to(device=device, dtype=dtype) - ) - - if verbose: - print(f"Replacing {name}.{k} to", klass) - - return model - def _create_mask(l, device): """1 is valid region and 0 is invalid.""" seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) @@ -485,6 +448,14 @@ class Base(nn.Module): self.sep = nn.Parameter(torch.randn(d_model)) + # ick, there has to be a better way + attention = self.config.attention if self.config is not None else None + use_xformers = False + + if attention == "xformers": + use_xformers = True + attention = None + if self.arch_type == "transformer": self.sin_emb = SinusoidalEmbedding(d_model) self.blocks = nn.ModuleList([TransformerBlock( @@ -495,7 +466,7 @@ class Base(nn.Module): norm_type=self.norm_type, n_levels=self.n_resp_levels, ) for _ in range(n_layers) ]) - elif self.arch_type == "mistral" or self.arch_type == "mixtral": + elif self.arch_type in ["mistral", "mixtral"]: if n_experts <= 1: self.model = MistralModel(MistralConfig( vocab_size=n_resp_tokens, @@ -509,7 +480,7 @@ class Base(nn.Module): hidden_act="gelu", is_encoder_decoder=False, is_decoder=True, - attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2", + attn_implementation=attention, )) else: self.model = MixtralModel(MixtralConfig( @@ -528,18 +499,10 @@ class Base(nn.Module): is_decoder=True, num_local_experts=n_experts, num_experts_per_tok=min(2, n_experts), - attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2", + attn_implementation=attention, )) elif self.arch_type == "llama": if n_experts <= 1: - # ick, there has to be a better way - attention = self.config.attention if self.config is not None else None # "flash_attention_2", - use_xformers = False - - if attention == "xformers": - use_xformers = True - attention = None - self.model = LlamaModel(LlamaConfig( vocab_size=n_resp_tokens, hidden_size=d_model, @@ -555,9 +518,6 @@ class Base(nn.Module): is_decoder=True, attn_implementation=attention, )) - - if use_xformers: - self.model = replace_attention( self.model, "xformers" if use_xformers else attention ) else: self.model = MixtralModel(MixtralConfig( vocab_size =n_resp_tokens, @@ -575,9 +535,8 @@ class Base(nn.Module): is_decoder=True, num_local_experts=n_experts, 
num_experts_per_tok=min(2, n_experts), - attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2", + attn_implementation=attention, )) - elif self.arch_type == "retnet": kwargs = dict( vocab_size=n_resp_tokens, @@ -589,9 +548,9 @@ class Base(nn.Module): dropout=p_dropout if training else 0.0, checkpoint_activations=self.activation_checkpointing, activation_fn="gelu", - use_layernorm=True, # self.version < 3, - use_biases=True, # self.version < 3, - use_glu=False, # self.version >= 3, + use_layernorm=self.version < 3, + use_biases=self.version < 3, + use_glu=self.version >= 3, chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0, recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0, @@ -642,6 +601,9 @@ class Base(nn.Module): else: raise RuntimeError(f'Unknown arch specified: {self.arch_type}') + if use_xformers: + self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention ) + self.classifier = nn.Linear(d_model, n_resp_tokens) self.accuracy_metric = MulticlassAccuracy( diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py index 76c409c..fcd1275 100755 --- a/vall_e/utils/wrapper.py +++ b/vall_e/utils/wrapper.py @@ -9,6 +9,11 @@ from ..config import cfg Embedding = torch.nn.Embedding Linear = torch.nn.Linear +Adam = torch.optim.Adam +AdamW = torch.optim.AdamW +SGD = torch.optim.SGD +Adagrad = torch.optim.Adagrad + # https://github.com/kyegomez/BitNet if cfg.optimizations.bitnet: from bitnet import BitLinear @@ -37,19 +42,20 @@ if cfg.optimizations.bitsandbytes: )).to(self.weight.dtype) ) """ + if cfg.optimizations.optimizers: + Adam = bnb.optim.Adam8bit + AdamW = bnb.optim.AdamW8bit + SGD = bnb.optim.SGD8bit + Adagrad = bnb.optim.Adagrad8bit -if cfg.optimizations.bitsandbytes: - import bitsandbytes as bnb +elif cfg.optimizations.dadaptation: + import dadaptation - Adam = bnb.optim.Adam8bit - AdamW = bnb.optim.AdamW8bit - SGD = bnb.optim.SGD8bit - Adagrad = bnb.optim.Adagrad8bit -else: - Adam = torch.optim.Adam - AdamW = torch.optim.AdamW - SGD = torch.optim.SGD - Adagrad = torch.optim.Adagrad + if cfg.optimizations.optimizers: + Adam = dadaptation.DAdaptAdam + AdamW = dadaptation.DAdaptAdam + SGD = dadaptation.DAdaptSGD + AdaGrad = dadaptation.DAdaptAdaGrad # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16) @contextmanager @@ -92,42 +98,112 @@ else: def autocast(): yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp) -if cfg.optimizations.injects and cfg.optimizations.bitsandbytes: - torch.nn.Linear = Linear - torch.nn.Embedding = Embedding +if cfg.optimizations.injects: + if cfg.optimizations.linear: + torch.nn.Linear = Linear + + if cfg.optimizations.embedding: + torch.nn.Embedding = Embedding - torch.optim.Adam = Adam - torch.optim.AdamW = AdamW - torch.optim.SGD = SGD + if cfg.optimizations.optimizers: + torch.optim.Adam = Adam + torch.optim.AdamW = AdamW + torch.optim.SGD = SGD # disgusting kludge, but it works (just realized BitNet has its own replacement routine) -def replace_linear( model, verbose=False ): +# generalizing this would be super sugoi but the there's no catch all for arguments +def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ): bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet device = next(model.parameters()).device - linears = [k.split('.') for k, m in model.named_modules() if 
isinstance(m, torch.nn.Linear)] - klass = Linear + dtype = next(model.parameters()).dtype + modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] - for *parent, k in linears: + for *parent, k in modules: name = '.'.join(parent) - - # copy parameters m = getattr( model.get_submodule(name), k ) if isinstance(m, klass): continue - in_features = m.in_features - out_features = m.out_features - bias = m.bias is not None - - kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias) + kwargs = dict( + in_features = m.in_features, + out_features = m.out_features, + bias = m.bias is not None, + ) if not bnb else dict( + input_features=m.in_features, + output_features=m.out_features, + bias=m.bias is not None, + ) # overwrite setattr( model.get_submodule(name), k, - klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype) + klass( **kwargs ).to(device=device, dtype=dtype) + ) + + if verbose: + print(f"Replacing {name}.{k} to", klass) + + return model + +def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] + + for *parent, k in modules: + name = '.'.join(parent) + + m = getattr( model.get_submodule(name), k ) + + if isinstance(m, klass): + continue + + kwargs = dict( + num_embeddings=m.num_embeddings, + embedding_dim=m.embedding_dim, + padding_idx=m.padding_idx, + max_norm=m.max_norm, + norm_type=m.norm_type, + scale_grad_by_freq=m.scale_grad_by_freq, + sparse=m.sparse, + ) + + # overwrite + setattr( + model.get_submodule(name), k, + klass( **kwargs ).to(device=device, dtype=dtype) + ) + + if verbose: + print(f"Replacing {name}.{k} to", klass) + + return model + +# cannot feasibly do default arguments here sad +def replace_attention( model, klass, target, verbose=False ): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] + + for *parent, k in modules: + name = '.'.join(parent) + + m = getattr( model.get_submodule(name), k ) + + if isinstance(m, klass): + continue + + kwargs = dict( + config = m.config, + layer_idx = m.layer_idx, + ) + # overwrite + setattr( + model.get_submodule(name), k, + klass( **kwargs ).to(device=device, dtype=dtype) ) if verbose: @@ -139,4 +215,12 @@ def replace_linear( model, verbose=False ): try: from prodigyopt import Prodigy except Exception as e: + print('Error while importing Prodigyopt:', str(e)) + pass + +# https://github.com/facebookresearch/schedule_free/ +try: + import schedulefree +except Exception as e: + print('Error while importing Schedule_Free:', str(e)) pass \ No newline at end of file
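
A minimal sketch of the ScheduleFree path wired up above, assuming only the upstream schedulefree package API (AdamWScheduleFree and its documented .train()/.eval() mode switches); the tiny linear model and training loop below are placeholders for illustration, not code from this repository:

    import torch
    import schedulefree

    model = torch.nn.Linear(16, 4)  # stand-in for the real model
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1.0e-4)

    optimizer.train()          # ScheduleFree tracks two weight sequences,
    for _ in range(100):       # so it must be told when training begins...
        x = torch.randn(8, 16)
        loss = model(x).square().mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    optimizer.eval()           # ...and when evaluation or checkpointing happens,
    with torch.no_grad():      # so the averaged weights are the ones actually used.
        model(torch.randn(1, 16))

With the config changes in the first hunk, this path is selected by setting hyperparameters.scheduler to "schedulefree" (plus torch_optimizer: True when running under the DeepSpeed backend); the deprecated scheduler_type key is migrated to scheduler, and schedulefree is dropped whenever optimizations.dadaptation is enabled, since the two are not meant to be combined. The hunks above only construct the optimizer; schedulefree's documentation notes that optimizer.train() and optimizer.eval() still need to be called around the training and evaluation phases.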