Refactored simulated fp8 modules into research.nn.

parent e67bfccbcd
commit dd562c24f1
@@ -0,0 +1 @@
from .modules import LinearFP8Mixed, LinearFP8Global
@@ -0,0 +1,64 @@
from typing import TypeVar

import torch
from torch import nn

import bitsandbytes as bnb

T = TypeVar("T", bound="torch.nn.Module")


class LinearFP8Mixed(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.bw_code = None
        self.fw_code = None
        # Pick quantization block sizes by rounding each feature dimension
        # up to the nearest size in the list (capped at 4096).
        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
        for i, k in enumerate(array):
            if input_features > array[i + 1]:
                self.bsz = k
                break
        for i, k in enumerate(array):
            if output_features > array[i + 1]:
                self.bsz2 = k
                break

    def forward(self, x: torch.Tensor):
        if self.fw_code is None:
            # Lazily build the quantization maps on the input device:
            # E5M2 for the backward pass, E4M3 for the forward pass.
            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

        out = bnb.research.matmul_fp8_mixed(
            x, self.weight.t(),
            fw_code=self.fw_code, bw_code=self.bw_code,
            bsz=self.bsz, bsz2=self.bsz2,
        )
        if self.bias is not None:
            out += self.bias

        return out


class LinearFP8Global(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.bw_code = None
        self.fw_code = None
        # Same block-size selection as LinearFP8Mixed.
        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
        for i, k in enumerate(array):
            if input_features > array[i + 1]:
                self.bsz = k
                break
        for i, k in enumerate(array):
            if output_features > array[i + 1]:
                self.bsz2 = k
                break

    def forward(self, x: torch.Tensor):
        if self.fw_code is None:
            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

        out = bnb.matmul_fp8_global(
            x, self.weight.t(),
            fw_code=self.fw_code, bw_code=self.bw_code,
            bsz=self.bsz, bsz2=self.bsz2,
        )
        if self.bias is not None:
            out += self.bias

        return out
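For orientation (not part of the commit): a minimal usage sketch of the new modules, assuming the __init__.py shown above makes them importable as bitsandbytes.research.nn and a CUDA device is available. The layer shapes are illustrative.

import torch
from bitsandbytes.research.nn import LinearFP8Mixed

# Drop-in replacement for torch.nn.Linear; the matmul runs through
# simulated fp8 (E4M3 forward, E5M2 backward).
layer = LinearFP8Mixed(768, 3072).cuda()
x = torch.randn(4, 768, device="cuda", requires_grad=True)
out = layer(x)
out.sum().backward()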
@@ -0,0 +1,27 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_NEW_TOKENS = 128
model_name = 'decapoda-research/llama-7b-hf'

text = 'Hamburg is in which country?\n'
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids

# Per-GPU budget: currently free memory minus a 2 GB headroom.
free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
max_memory = f'{free_in_GB - 2}GB'

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    load_in_8bit=True,
    max_memory=max_memory,
)
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
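A further hedged sketch (also not part of the commit) of how a model like the one above could exercise the new fp8 modules: recursively swap each nn.Linear for LinearFP8Mixed, reusing the existing parameters. The helper name is made up, and this assumes the model is loaded in fp16/fp32 rather than with load_in_8bit.

import torch.nn as nn
from bitsandbytes.research.nn import LinearFP8Mixed

def swap_linear_for_fp8(module: nn.Module) -> None:
    # Replace every nn.Linear submodule with a simulated-fp8 equivalent,
    # keeping the original weight and bias parameters.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            fp8 = LinearFP8Mixed(child.in_features, child.out_features,
                                 bias=child.bias is not None)
            fp8.weight = child.weight
            if child.bias is not None:
                fp8.bias = child.bias
            setattr(module, name, fp8)
        else:
            swap_linear_for_fp8(child)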