Test for bloom that fails with inference kernels.

parent ae7cd6ad14
commit dc96e9e7c8

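The new test matrix covers every (model, quant type) combination, with ids such as `huggyllama/llama-7b_nf4` and `bigscience/bloom-1b7_fp4`, each run with and without the inference kernel. As a minimal sketch of how the failing Bloom cases could be selected (the test file path is an assumption; the diff below does not show the file name):

import pytest

# Hypothetical path; substitute wherever this test actually lives.
# '-k' does substring matching on the generated ids, so 'bloom' picks the
# bigscience/bloom-1b7 cases and 'inference_kernel_True' the kernel runs.
pytest.main(['tests/test_generation.py', '-k', 'bloom and inference_kernel_True', '-v'])
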
@@ -2,6 +2,9 @@ import pytest
 import torch
 import math
 
+from itertools import product
+
+import transformers
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -11,7 +14,7 @@ from transformers import (
     set_seed,
 
 )
-import transformers
+
 
 
 def get_4bit_config():
@@ -26,15 +29,23 @@ def get_4bit_config():
     )
 
 
-def get_model(model_name_or_path='huggyllama/llama-7b', bnb_config=get_4bit_config()):
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
+def get_model_and_tokenizer(config):
+    model_name_or_path, quant_type = config
+    bnb_config = get_4bit_config()
+    if quant_type == '16bit':
+        bnb_config.load_in_4bit = False
+    else:
+        bnb_config.bnb_4bit_quant_type= quant_type
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
         quantization_config=bnb_config,
         max_memory={0:'48GB'},
-        device_map='auto'
+        device_map='auto',
+        torch_dtype=torch.bfloat16
         ).eval()
 
-    return model
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
+
+    return model, tokenizer
 
 def get_prompt_for_generation_eval(text, add_roles=True):
     description = (
@@ -53,48 +64,66 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
     outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-name_or_path = 'huggyllama/llama-7b'
-#name_or_path = 'AI-Sweden/gpt-sw3-126m'
-
-@pytest.fixture(scope='session')
-def model():
-    bnb_config = get_4bit_config()
-    bnb_config.bnb_4bit_compute_dtype=torch.float32
-    bnb_config.load_in_4bit=True
-    model = get_model(name_or_path)
-    print('')
-    return model
+models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
+dtypes = ['nf4', 'fp4', '16bit']
+load_in_4bit = [True, False]
+values = list(product(models, dtypes))
+strfunc = lambda lst: [str(x) for x in lst]
+ids = ['_'.join(strfunc(x)) for x in values]
+@pytest.fixture(scope='session', params=values, ids=ids)
+def model_and_tokenizer(request):
+    model, tokenizer = get_model_and_tokenizer(request.param)
+    yield model, tokenizer
+    del model
 
-@pytest.fixture(scope='session')
-def tokenizer():
-    tokenizer = transformers.AutoTokenizer.from_pretrained(name_or_path)
-    return tokenizer
-
 
+@pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_pi(model, tokenizer, dtype):
+def test_pi(model_and_tokenizer, dtype, inference_kernel):
+
+    model, tokenizer = model_and_tokenizer
 
     generation_config = transformers.GenerationConfig(
-        max_new_tokens=128,
+        max_new_tokens=20,
         do_sample=True,
         top_p=0.9,
         temperature=0.7,
     )
-    generation_config.max_new_tokens = 50
+    generation_config.max_new_tokens = 20
 
 
     #text = 'Please write down the first 50 digits of pi.'
     #text = get_prompt_for_generation_eval(text)
     #text += ' Sure, here the first 50 digits of pi: 3.14159'
+    n_cases = 3
     text = '3.14159'
-    model.config.quantization_config.bnb_4bit_compute_dtype = dtype
+    if hasattr(model.config, 'quantization_config'):
+        model.config.quantization_config.bnb_4bit_compute_dtype = dtype
 
+    if not inference_kernel:
+        text = [text]*n_cases
     inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
-    outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
-    textout = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    print('')
-    print(textout)
-    print(math.pi)
+    x = inputs['input_ids']
+    failure_count = 0
+    outputs = []
+    if inference_kernel:
+        for i in range(n_cases):
+            output = model.generate(x, generation_config=generation_config)
+            textout = tokenizer.decode(output[0], skip_special_tokens=True)
+            outputs.append(textout)
+    else:
+        outputs = model.generate(x, generation_config=generation_config)
+        outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
 
 
-    assert textout[:len(str(math.pi))] == str(math.pi)
+    assert len(outputs) == n_cases
+    for i in range(n_cases):
+        if not outputs[i][:len(str(math.pi))] == str(math.pi):
+            failure_count += 1
+    if failure_count > 1:
+        print(math.pi)
+        for out in outputs:
+            print(out)
+        raise ValueError(f'Failure count: {failure_count}/{n_cases}')
+
+
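
For reference, the parameterized fixture introduced above follows the standard pytest pattern of building `params` and `ids` from `itertools.product`, with each test instance receiving one `(model, quant_type)` pair through `request.param`. A minimal self-contained sketch of that pattern, using a cheap stand-in instead of the real quantized-model loading:

from itertools import product

import pytest

models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
dtypes = ['nf4', 'fp4', '16bit']
values = list(product(models, dtypes))
ids = ['_'.join(map(str, v)) for v in values]


@pytest.fixture(scope='session', params=values, ids=ids)
def model_and_tokenizer_stub(request):
    # Stand-in for get_model_and_tokenizer(request.param); the real fixture
    # loads and quantizes a full model, then yields (model, tokenizer).
    model_name, quant_type = request.param
    yield {'name': model_name, 'quant_type': quant_type}, None


@pytest.mark.parametrize('inference_kernel', [True, False])
def test_matrix_stub(model_and_tokenizer_stub, inference_kernel):
    # One invocation per (model, quant_type, inference_kernel) combination.
    model, tokenizer = model_and_tokenizer_stub
    assert model['quant_type'] in dtypes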