vall-e/vall_e/utils/wrapper.py

242 lines
6.2 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
from contextlib import contextmanager
import math
2023-08-02 21:53:35 +00:00
import torch
import torch.nn.functional as F
2023-08-02 23:36:26 +00:00
from ..config import cfg
2023-08-02 21:53:35 +00:00
# Default layer and optimizer classes. The backend switches that follow in
# this module may rebind any of these names to drop-in replacements.
Embedding, Linear = torch.nn.Embedding, torch.nn.Linear
Adam, AdamW, SGD, Adagrad = (
    torch.optim.Adam,
    torch.optim.AdamW,
    torch.optim.SGD,
    torch.optim.Adagrad,
)
# Optional backend swaps, driven by cfg.optimizations.
# https://github.com/kyegomez/BitNet
if cfg.optimizations.bitnet:
    from bitnet import BitLinear

if cfg.optimizations.bitsandbytes:
    import bitsandbytes as bnb

    if cfg.optimizations.linear:
        # As structured here, BitLinear is only selected when bitsandbytes is
        # enabled as well; otherwise the 8-bit bnb linear layer is used.
        if cfg.optimizations.bitnet:
            Linear = BitLinear
        else:
            Linear = bnb.nn.Linear8bitLt

    if cfg.optimizations.embedding:
        Embedding = bnb.nn.modules.Embedding
        # Disabled experiment, kept for reference: a forward that applies
        # self.norm to the looked-up embeddings and casts back to the weight dtype.
        # Embedding.forward = lambda self, input: ( self.norm(F.embedding(
        #     input,
        #     self.weight,
        #     self.padding_idx,
        #     self.max_norm,
        #     self.norm_type,
        #     self.scale_grad_by_freq,
        #     self.sparse,
        # )).to(self.weight.dtype) )

    if cfg.optimizations.optimizers:
        Adam = bnb.optim.Adam8bit
        AdamW = bnb.optim.AdamW8bit
        SGD = bnb.optim.SGD8bit
        Adagrad = bnb.optim.Adagrad8bit

elif cfg.optimizations.dadaptation:
    import dadaptation

    if cfg.optimizations.optimizers:
        Adam = dadaptation.DAdaptAdam
        # AdamW is mapped to DAdaptAdam as well; no distinct AdamW class is used.
        AdamW = dadaptation.DAdaptAdam
        SGD = dadaptation.DAdaptSGD
        # Bind the canonical `Adagrad` name so this branch actually overrides
        # torch.optim.Adagrad like the bitsandbytes branch does; `AdaGrad` is
        # kept as an alias for any code that used the old spelling.
        Adagrad = dadaptation.DAdaptAdaGrad
        AdaGrad = Adagrad
# Generic "use a different dtype inside a block" helper (written for bfloat16).
# NOTE(review): this module later rebinds the name `autocast` to a
# zero-argument context manager, so importers see that one, not this helper.
@contextmanager
def autocast(input, from_dtype, to_dtype):
    """Yield ``input`` converted to ``to_dtype`` when its dtype is ``from_dtype``.

    Tensors with any other dtype are yielded unchanged. ``Tensor.to`` is
    out-of-place, so the caller's tensor is never modified.
    """
    if input.dtype == from_dtype:
        # No conversion back after the block: reassigning a generator-local
        # name can never reach the caller, so it would only waste a copy.
        yield input.to(to_dtype)
    else:
        yield input
@contextmanager
def autocasts(input, from_dtype, to_dtype):
    """Yield ``input`` converted to ``to_dtype`` when its dtype is in ``from_dtype``.

    ``from_dtype`` is a collection of source dtypes (anything supporting
    ``in``). Tensors with any other dtype are yielded unchanged, and the
    caller's tensor is never modified (``Tensor.to`` is out-of-place).
    """
    if input.dtype in from_dtype:
        # No conversion back after the block: it would only rebind a local
        # name that the caller never sees, so it was pure overhead.
        yield input.to(to_dtype)
    else:
        yield input
# Index-tensor dtypes that are upcast to int32 before an embedding lookup.
# Float dtypes are included too; any fractional part is dropped by the cast.
_INDEX_UPCAST_DTYPES = (torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16)


def autocast_forward( func ):
    """Wrap ``forward(self, input, ...)`` so ``input`` is upcast to int32 first.

    Inputs whose dtype is in ``_INDEX_UPCAST_DTYPES`` are converted (via
    ``autocasts``) before ``func`` runs, so torch accepts them as indices;
    other inputs are passed through unchanged. Extra positional and keyword
    arguments are forwarded untouched.
    """
    from functools import wraps

    @wraps(func)  # keep the wrapped forward's name/docstring for introspection
    def wrapper( self, input, *args, **kwargs ):
        with autocasts( input, _INDEX_UPCAST_DTYPES, torch.int32 ) as k:
            return func( self, k, *args, **kwargs )
    return wrapper


# Patch the embedding class selected above so every lookup gets the upcast.
# NOTE(review): when Embedding is still torch.nn.Embedding this patches
# torch's class globally, for all users in the process.
Embedding.forward = autocast_forward(Embedding.forward)
# Mixed-precision context for forward passes, chosen once at import time.
# NOTE: this rebinds the module-level name `autocast`, replacing the
# three-argument tensor-cast helper defined earlier in this module.
if cfg.optimizations.fp8:
    import transformer_engine.pytorch as te

    Linear = te.Linear

    @contextmanager
    def autocast():
        """Run the enclosed block under TransformerEngine FP8 autocasting."""
        # The manager must be entered here; merely yielding it would leave FP8
        # autocasting inactive inside the caller's `with` block.
        with te.fp8_autocast(enabled=True):
            yield
else:
    @contextmanager
    def autocast():
        """Run the enclosed block under torch.autocast using the trainer's dtype/amp settings."""
        # Entered here for the same reason: an un-entered torch.autocast has no effect.
        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
            yield
# Optionally monkey-patch torch itself, so code that instantiates torch.nn /
# torch.optim classes directly picks up the replacements chosen above.
# NOTE(review): Adagrad is never injected, unlike Adam/AdamW/SGD; confirm intent.
if cfg.optimizations.injects and cfg.optimizations.linear:
    torch.nn.Linear = Linear
if cfg.optimizations.injects and cfg.optimizations.embedding:
    torch.nn.Embedding = Embedding
if cfg.optimizations.injects and cfg.optimizations.optimizers:
    torch.optim.Adam, torch.optim.AdamW, torch.optim.SGD = Adam, AdamW, SGD
# Generic module-swapping helper (BitNet also ships its own replacement
# routine). It is not fully generic, because each replacement class needs its
# own constructor arguments.
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
    """Replace every ``target`` layer in ``model`` with a newly built ``klass``.

    Each replacement is constructed from the old layer's input/output sizes and
    whether it has a bias. It is then moved to the device/dtype of the model's
    first parameter (if the model has any). Weights are NOT copied; the new
    layers are freshly initialized. Layers already of type ``klass`` are
    skipped, and so is ``model`` itself, since a root module has no parent
    attribute to reassign.

    Args:
        model: module to modify in place.
        klass: replacement layer class.
        target: class of the layers to replace.
        verbose: print each replacement.

    Returns:
        ``model``.
    """
    import inspect

    # bitsandbytes' Linear8bitLt names its size arguments input_features /
    # output_features instead of in_features / out_features. Decide from the
    # class actually being built: with fp8 enabled, the default `Linear` is
    # te.Linear even when the bitsandbytes options are also set.
    try:
        bnb_names = "input_features" in inspect.signature(klass).parameters
    except (TypeError, ValueError):
        # Signature unavailable: fall back to the config-based guess.
        bnb_names = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet

    first = next(model.parameters(), None)

    # named_modules() reports the root as "", which cannot be re-assigned.
    paths = [name.split('.') for name, m in model.named_modules() if name and isinstance(m, target)]
    for *parent, k in paths:
        name = '.'.join(parent)
        owner = model.get_submodule(name)
        m = getattr(owner, k)
        if isinstance(m, klass):
            continue

        if bnb_names:
            kwargs = dict(
                input_features=m.in_features,
                output_features=m.out_features,
                bias=m.bias is not None,
            )
        else:
            kwargs = dict(
                in_features=m.in_features,
                out_features=m.out_features,
                bias=m.bias is not None,
            )

        replacement = klass(**kwargs)
        if first is not None:
            replacement = replacement.to(device=first.device, dtype=first.dtype)
        setattr(owner, k, replacement)

        if verbose:
            print(f"Replacing {name}.{k} to", klass)
    return model
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
    """Replace every ``target`` embedding in ``model`` with a newly built ``klass``.

    The replacement is constructed from the old module's num_embeddings,
    embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq and
    sparse settings. It is then moved to the device/dtype of the model's
    first parameter (if the model has any). Weights are NOT copied. Modules
    already of type ``klass`` are skipped, and so is ``model`` itself, since a
    root module has no parent attribute to reassign.

    Args:
        model: module to modify in place.
        klass: replacement embedding class.
        target: class of the modules to replace.
        verbose: print each replacement.

    Returns:
        ``model``.
    """
    first = next(model.parameters(), None)

    # named_modules() reports the root as "", which cannot be re-assigned.
    paths = [name.split('.') for name, m in model.named_modules() if name and isinstance(m, target)]
    for *parent, k in paths:
        name = '.'.join(parent)
        owner = model.get_submodule(name)
        m = getattr(owner, k)
        if isinstance(m, klass):
            continue

        replacement = klass(
            num_embeddings=m.num_embeddings,
            embedding_dim=m.embedding_dim,
            padding_idx=m.padding_idx,
            max_norm=m.max_norm,
            norm_type=m.norm_type,
            scale_grad_by_freq=m.scale_grad_by_freq,
            sparse=m.sparse,
        )
        if first is not None:
            replacement = replacement.to(device=first.device, dtype=first.dtype)
        setattr(owner, k, replacement)

        if verbose:
            print(f"Replacing {name}.{k} to", klass)
    return model
# No defaults for klass/target: there is no single sensible attention class to
# default to.
def replace_attention( model, klass, target, mode="math", verbose=False ):
    """Replace every ``target`` attention module in ``model`` with a new ``klass``.

    Each replacement is built as ``klass(config=..., layer_idx=..., mode=mode)``
    from the old module's ``config`` and ``layer_idx``. It is then moved to
    the device/dtype of the model's first parameter (if the model has any).
    Weights are NOT copied. Modules already of type ``klass`` are skipped, and
    so is ``model`` itself, since a root module has no parent attribute to
    reassign.

    Args:
        model: module to modify in place.
        klass: replacement attention class.
        target: class of the modules to replace.
        mode: forwarded to ``klass`` as its ``mode`` argument.
        verbose: print each replacement.

    Returns:
        ``model``.
    """
    first = next(model.parameters(), None)

    # named_modules() reports the root as "", which cannot be re-assigned.
    paths = [name.split('.') for name, m in model.named_modules() if name and isinstance(m, target)]
    for *parent, k in paths:
        name = '.'.join(parent)
        owner = model.get_submodule(name)
        m = getattr(owner, k)
        if isinstance(m, klass):
            continue

        replacement = klass(
            config=m.config,
            layer_idx=m.layer_idx,
            mode=mode,
        )
        if first is not None:
            replacement = replacement.to(device=first.device, dtype=first.dtype)
        setattr(owner, k, replacement)

        if verbose:
            print(f"Replacing {name}.{k} to", klass)
    return model
# Trim or grow a tensor along one dimension (for example, a weight in a state dict).
def resize_weight( weight, target, dim=0, random=True ):
    """Return ``weight`` resized to length ``target`` along dimension ``dim``.

    Trimming keeps the leading ``target`` entries and returns a view. Growing
    appends new entries built with ``torch.rand`` when ``random`` is true
    (otherwise ``torch.zeros``), then cast to ``weight``'s device and dtype.
    If the size already matches, ``weight`` itself is returned.
    """
    current = weight.shape[dim]
    if target < current:
        # narrow() cuts along `dim`, so non-zero (and negative) dims are honoured.
        return weight.narrow(dim, 0, target)
    if target > current:
        fill = torch.rand if random else torch.zeros
        pad_shape = list(weight.shape)
        pad_shape[dim] = target - current
        pad = fill(pad_shape).to(device=weight.device, dtype=weight.dtype)
        # A single concat along `dim`; this also works when `weight` is empty along it.
        return torch.cat([weight, pad], dim=dim)
    return weight
# Optional Prodigy optimizer (https://github.com/konstmish/prodigy).
# Best-effort: any import failure is reported on stdout and otherwise ignored,
# in which case `Prodigy` is simply left undefined.
try:
    from prodigyopt import Prodigy
except Exception as import_error:
    print('Error while importing Prodigyopt:', str(import_error))
# Optional schedule-free optimizers (https://github.com/facebookresearch/schedule_free/).
# Best-effort: any import failure is reported on stdout and otherwise ignored,
# in which case `schedulefree` is simply left undefined.
try:
    import schedulefree
except Exception as import_error:
    print('Error while importing Schedule_Free:', str(import_error))