from contextlib import contextmanager
from functools import wraps

import math
import torch
import torch.nn.functional as F

from ..config import cfg

# Default layer / optimizer classes; the `cfg.optimizations` flags below may swap
# them for bitsandbytes, BitNet, or D-Adaptation variants.
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear

Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
SGD = torch.optim.SGD
Adagrad = torch.optim.Adagrad

# https://github.com/kyegomez/BitNet
if cfg.optimizations.bitnet:
	from bitnet import BitLinear

if cfg.optimizations.bitsandbytes:
	import bitsandbytes as bnb

	if cfg.optimizations.linear:
		# BitLinear is only selected when bitsandbytes is also enabled
		if cfg.optimizations.bitnet:
			Linear = BitLinear
		else:
			Linear = bnb.nn.Linear8bitLt

	if cfg.optimizations.embedding:
		Embedding = bnb.nn.modules.Embedding
		# Disabled experiment, kept for reference:
		# Embedding.forward = lambda self, input: (
		# 	self.norm(F.embedding(
		# 		input,
		# 		self.weight,
		# 		self.padding_idx,
		# 		self.max_norm,
		# 		self.norm_type,
		# 		self.scale_grad_by_freq,
		# 		self.sparse,
		# 	)).to(self.weight.dtype)
		# )

	if cfg.optimizations.optimizers:
		Adam = bnb.optim.Adam8bit
		AdamW = bnb.optim.AdamW8bit
		SGD = bnb.optim.SGD8bit
		Adagrad = bnb.optim.Adagrad8bit

elif cfg.optimizations.dadaptation:
	import dadaptation

	if cfg.optimizations.optimizers:
		Adam = dadaptation.DAdaptAdam
		AdamW = dadaptation.DAdaptAdam
		SGD = dadaptation.DAdaptSGD
		# previously assigned only to the misspelled `AdaGrad`, which left `Adagrad` un-swapped
		Adagrad = dadaptation.DAdaptAdaGrad
		AdaGrad = Adagrad  # legacy alias for the old misspelling


# Handles generically converting to a specific tensor type (implemented solely for bfloat16).
# NOTE: the name `autocast` is rebound further down in this module by the AMP/FP8
# `autocast()` context manager, so importers of this module see that one instead.
@contextmanager
def autocast(input, from_dtype, to_dtype):
	"""Yield `input` cast to `to_dtype` when its dtype is exactly `from_dtype`, else `input` unchanged.

	`.to()` returns a new tensor and never modifies `input`, so nothing needs to be
	converted back when the context exits.
	"""
	yield input.to(to_dtype) if input.dtype == from_dtype else input


@contextmanager
def autocasts(input, from_dtype, to_dtype):
	"""Yield `input` cast to `to_dtype` when its dtype is any of `from_dtype`, else `input` unchanged.

	`from_dtype` is a collection of dtypes (tested with `in`). As with `autocast`, the
	caller's tensor is never modified, so there is no conversion back on exit.
	"""
	yield input.to(to_dtype) if input.dtype in from_dtype else input


# Handles temporarily upcasting 'index tensors' so torch will stop complaining.
def autocast_forward( func ):
	"""Wrap a module `forward(self, input, ...)` so `input` is cast to int32 when it
	arrives as int16, int8, uint8, float16 or bfloat16; other dtypes pass through.
	"""
	# built once per wrapped function instead of on every forward call
	index_dtypes = ( torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16 )

	@wraps(func)
	def wrapper( self, input, *args, **kwargs ):
		with autocasts( input, index_dtypes, torch.int32 ) as k:
			return func( self, k, *args, **kwargs )
	return wrapper
# Wrap the active Embedding class's forward so narrow-int / half-precision index tensors
# are upcast to int32 before lookup.
# NOTE(review): when no embedding override is active, `Embedding` is `torch.nn.Embedding`
# itself, so this patches torch's class for every user in the process.
Embedding.forward = autocast_forward(Embedding.forward)

if cfg.optimizations.fp8:
	import transformer_engine.pytorch as te

	Linear = te.Linear

	@contextmanager
	def autocast():
		"""Run the enclosed block under TransformerEngine FP8 autocast; yields the context object."""
		ctx = te.fp8_autocast(enabled=True)
		# the context must actually be entered: yielding it unentered left FP8 disabled
		with ctx:
			yield ctx
else:
	@contextmanager
	def autocast():
		"""Run the enclosed block under CUDA `torch.autocast` (dtype/enabled from `cfg.trainer`); yields the context object."""
		ctx = torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
		# the context must actually be entered: yielding it unentered left AMP disabled
		with ctx:
			yield ctx

# Optionally monkey-patch torch so code that constructs torch.nn.* / torch.optim.* classes
# directly picks up the overrides selected above.
if cfg.optimizations.injects:
	if cfg.optimizations.linear:
		torch.nn.Linear = Linear
	if cfg.optimizations.embedding:
		torch.nn.Embedding = Embedding

	if cfg.optimizations.optimizers:
		torch.optim.Adam = Adam
		torch.optim.AdamW = AdamW
		torch.optim.SGD = SGD
		torch.optim.Adagrad = Adagrad


def _replace_modules( model, klass, target, build_kwargs, verbose=False ):
	"""Swap every submodule of `model` that is an instance of `target` for a new `klass`.

	`build_kwargs(m)` maps the module being replaced to constructor kwargs for `klass`.
	Modules that are already `klass` instances are skipped. Nothing is copied from the
	replaced module: the new module's parameters are whatever `klass`'s constructor
	creates, moved to the device/dtype of the model's first parameter (if it has any).
	Returns `model`, modified in place.
	"""
	first_param = next(model.parameters(), None)

	# snapshot the paths first so the module tree is not mutated while being iterated
	paths = [ path.split('.') for path, m in model.named_modules() if isinstance(m, target) ]
	for *parent, k in paths:
		name = '.'.join(parent)
		owner = model.get_submodule(name)
		m = getattr( owner, k )
		if isinstance(m, klass):
			continue

		replacement = klass( **build_kwargs(m) )
		if first_param is not None:
			replacement = replacement.to(device=first_param.device, dtype=first_param.dtype)

		# overwrite
		setattr( owner, k, replacement )

		if verbose:
			print(f"Replacing {name}.{k} to", klass)

	return model


def _uses_bnb_linear_kwargs( klass ):
	"""Return True if `klass` takes bitsandbytes-style `input_features`/`output_features` kwargs.

	Decided from `klass`'s constructor signature; when that is inconclusive, falls back
	to the config flags that select bitsandbytes' Linear8bitLt.
	"""
	import inspect

	try:
		params = inspect.signature(klass).parameters
	except (TypeError, ValueError):
		params = {}

	if 'input_features' in params:
		return True
	if 'in_features' in params:
		return False
	return bool( cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet )


# A kludge, but it works (BitNet also ships its own replacement routine).
# Each replace_* routine supplies its own constructor kwargs, since there is no
# catch-all way to map an existing module's configuration onto another class.
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
	"""Replace every `target` linear layer in `model` with a freshly constructed `klass`.

	Keyword names are picked to match `klass` (`in_features`/`out_features` for
	torch-style layers, `input_features`/`output_features` for bitsandbytes' layers).
	Returns `model`, modified in place.
	"""
	bnb_style = _uses_bnb_linear_kwargs( klass )

	def build_kwargs( m ):
		if bnb_style:
			return dict(
				input_features=m.in_features,
				output_features=m.out_features,
				bias=m.bias is not None,
			)
		return dict(
			in_features=m.in_features,
			out_features=m.out_features,
			bias=m.bias is not None,
		)

	return _replace_modules( model, klass, target, build_kwargs, verbose=verbose )


def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
	"""Replace every `target` embedding in `model` with a freshly constructed `klass`
	carrying the same size/padding/norm/sparsity settings. Returns `model`, modified in place.
	"""
	def build_kwargs( m ):
		return dict(
			num_embeddings=m.num_embeddings,
			embedding_dim=m.embedding_dim,
			padding_idx=m.padding_idx,
			max_norm=m.max_norm,
			norm_type=m.norm_type,
			scale_grad_by_freq=m.scale_grad_by_freq,
			sparse=m.sparse,
		)

	return _replace_modules( model, klass, target, build_kwargs, verbose=verbose )


# `klass`/`target` have no defaults here, unlike replace_linear / replace_embedding.
def replace_attention( model, klass, target, mode="math", verbose=False ):
	"""Replace every `target` attention module in `model` with `klass(config=..., layer_idx=..., mode=mode)`,
	reusing each replaced module's `config` and `layer_idx`. Returns `model`, modified in place.
	"""
	def build_kwargs( m ):
		return dict(
			config = m.config,
			layer_idx = m.layer_idx,
			mode = mode,
		)

	return _replace_modules( model, klass, target, build_kwargs, verbose=verbose )


# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target, dim=0, random=True ):
	"""Return `weight` trimmed or padded along `dim` so that its size there is `target`.

	Trimming keeps the first `target` slices (a view, like slicing). Padding appends new
	slices filled from `torch.rand` (if `random`) or `torch.zeros`, cast to `weight`'s
	device and dtype. If the size already matches, `weight` itself is returned.
	"""
	current = weight.shape[dim]

	# trim
	if target < current:
		return weight.narrow(dim, 0, target)

	# expand
	if target > current:
		fn = torch.rand if random else torch.zeros
		pad_shape = list(weight.shape)
		pad_shape[dim] = target - current
		padding = fn( pad_shape ).to(device=weight.device, dtype=weight.dtype)
		return torch.cat( [ weight, padding ], dim=dim )

	return weight


# https://github.com/konstmish/prodigy
try:
	from prodigyopt import Prodigy
except Exception as e:
	# optional dependency: report and carry on without it
	print('Error while importing Prodigyopt:', str(e))

# https://github.com/facebookresearch/schedule_free/
try:
	import schedulefree
except Exception as e:
	# optional dependency: report and carry on without it
	print('Error while importing Schedule_Free:', str(e))