vall-e/vall_e/utils/wrapper.py

60 lines
1.5 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
# to-do: re-introduce bitsandbytes support
from contextlib import contextmanager
from functools import wraps

import torch
import torch.nn.functional as F
# Module-level aliases for the torch layer classes. Code that imports
# `Embedding` / `Linear` from this module gets the stock torch implementations;
# the disabled bitsandbytes snippet that follows shows these names being
# rebound to 8-bit variants instead.
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
"""
if cfg.bitsandbytes:
    import bitsandbytes as bnb
    if cfg.bitsandbytes_linear:
        Linear = bnb.nn.Linear8bitLt
    if cfg.bitsandbytes_embedding:
        Embedding = bnb.nn.StableEmbedding
        Embedding.forward = lambda self, input: ( self.norm(F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )).to(self.weight.dtype) )
"""
# Module-level aliases for the optimizer classes, defaulting to the stock torch
# implementations. The disabled snippet that follows shows them being rebound
# to the bitsandbytes optimizers.
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
"""
if cfg.bitsandbytes:
    import bitsandbytes as bnb
    Adam = bnb.optim.Adam
    AdamW = bnb.optim.AdamW
"""
# Temporarily upcasts small-integer 'index tensors' to int32 before calling the
# wrapped forward, because torch's embedding lookup only accepts int32/int64
# indices.
def autocast_forward(func):
    """Decorate a ``forward(self, input, ...)`` method to upcast narrow int inputs.

    Args:
        func: The unbound forward method to wrap. It is called as
            ``func(self, input, *args, **kwargs)``.

    Returns:
        A wrapper with the same call signature. When ``input.dtype`` is
        ``torch.int16``, ``torch.int8`` or ``torch.uint8`` the input is
        converted to ``torch.int32`` first; any other dtype is passed through
        unchanged. Extra positional and keyword arguments are forwarded as-is.
    """
    # Dtypes that are too narrow to be used as indices and must be widened.
    narrow_index_dtypes = (torch.int16, torch.int8, torch.uint8)

    # `wraps` keeps the wrapped method's __name__/__doc__/signature visible on
    # the replacement, which matters because this is used to monkeypatch a
    # library class.
    @wraps(func)
    def wrapper(self, input, *args, **kwargs):
        if input.dtype in narrow_index_dtypes:
            input = input.to(torch.int32)
        return func(self, input, *args, **kwargs)

    return wrapper
# Replace Embedding.forward with the index-upcasting wrapper. NOTE: while
# `Embedding` is an alias of torch.nn.Embedding (see the assignment near the
# top of this module), this assignment patches the torch class itself, so it
# affects every torch.nn.Embedding instance in the process, not only ones
# created through this module.
Embedding.forward = autocast_forward(Embedding.forward)
# Generic helper that exposes a tensor under a different dtype for the duration
# of a `with` block (implemented solely for bfloat16).
@contextmanager
def autocast(input, from_dtype, to_dtype):
    """Yield ``input`` cast to ``to_dtype`` when its dtype equals ``from_dtype``.

    Example::

        with autocast(x, torch.bfloat16, torch.float32) as y:
            ...  # work with y

    Args:
        input: Object exposing ``.dtype`` and ``.to(dtype)`` (e.g. a torch
            tensor).
        from_dtype: The dtype that triggers the conversion.
        to_dtype: The dtype of the yielded value when the conversion happens.

    Yields:
        ``input.to(to_dtype)`` if ``input.dtype == from_dtype``; otherwise
        ``input`` itself, unchanged.

    Note:
        Nothing is converted back on exit. The object passed in is never
        rebound or modified by this helper, and a context manager cannot
        rebind the caller's ``as`` target, so any result computed inside the
        block must be cast back to ``from_dtype`` by the caller if required.
        The previous implementation performed a ``.to(from_dtype)`` after the
        ``yield`` whose result was immediately discarded; that dead copy has
        been removed.
    """
    if input.dtype == from_dtype:
        yield input.to(to_dtype)
    else:
        yield input