added flash_attn LlamaAttention (including flash_attn==1.0.9)

mrq 2024-08-18 20:51:14 -05:00
parent 054d28573a
commit d636edd3a2
5 changed files with 99 additions and 22 deletions

View File

@ -159,6 +159,8 @@ For audio backends:
* `sdpa`: integrated `LlamaSdpaAttention` attention model
* `flash_attention_2`: integrated `LlamaFlashAttention2` attention model
* `auto`: determine the best fit from the above
* `eager`: default `LlamaAttention`
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
The wide support for various backends is solely there while I try to figure out which is the "best" one for a core foundation model; a rough sketch of how their availability gets probed is below.
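As a minimal standalone sketch (not the project's exact code; the starting backend list and the exact version gates here are assumptions loosely mirroring this commit's detection logic), the probing amounts to roughly this:

```python
# Sketch of backend probing; default entries and version gates are assumptions.
AVAILABLE_ATTENTIONS = ["eager", "sdpa", "math", "mem_efficient", "flash", "cudnn"]

try:
    from transformers.utils import is_flash_attn_2_available
    if is_flash_attn_2_available():
        AVAILABLE_ATTENTIONS.append("flash_attention_2")
except Exception as e:
    print("Error while querying for `flash_attention_2` support", e)

try:
    import flash_attn
    version = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
    # 1.0.9 is handled through the compatibility wrapper; 2.2.1+ is used directly
    if version == [1, 0, 9] or version >= [2, 2, 1]:
        AVAILABLE_ATTENTIONS.append("flash_attn")
except Exception as e:
    print("Error while querying for `flash_attn` support", e)

print(AVAILABLE_ATTENTIONS)
```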

View File

@ -1332,24 +1332,28 @@ def create_dataset_metadata( skip_existing=True ):
if not os.path.isdir(f'{root}/{name}/'):
return
# tqdm.write(f'{root}/{name}')
files = os.listdir(f'{root}/{name}/')
# grab IDs for every file
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
wrote = False
for id in tqdm(ids, desc=f"Processing {name}"):
try:
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
if not quant_exists:
if audios and not quant_path.exists():
continue
key = f'{type}/{speaker_name}/{id}'
if skip_existing and id in metadata:
continue
wrote = True
if id not in metadata:
metadata[id] = {}
@ -1357,7 +1361,7 @@ def create_dataset_metadata( skip_existing=True ):
utterance_metadata = {}
if audios:
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
dac = np.load(quant_path, allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
if "text" in dac["metadata"]:
@ -1368,9 +1372,6 @@ def create_dataset_metadata( skip_existing=True ):
utterance_metadata["language"] = dac["metadata"]["language"]
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
# text
if texts and text_exists and not utterance_metadata:
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
for k, v in utterance_metadata.items():
metadata[id][k] = v
@ -1378,8 +1379,9 @@ def create_dataset_metadata( skip_existing=True ):
except Exception as e:
tqdm.write(f'Error while processing {id}: {e}')
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
if wrote:
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
# training
for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
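For reference, a minimal sketch of what reading one of these quantized-audio files looks like, following the keys used in the loop above; the path, extension, and the assumption that every metadata key is present are illustrative only:

```python
# Sketch: reading one quantized-audio file the way the metadata pass does.
from pathlib import Path
import numpy as np
import torch

quant_path = Path("training/speaker/utterance.dac")  # hypothetical example path
dac = np.load(quant_path, allow_pickle=True)[()]     # a pickled dict stored in npy format

# (timesteps, codebook levels) after the [0] index and transpose
codes = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
meta = dac.get("metadata", {})

duration = None
if "original_length" in meta and "sample_rate" in meta:
    duration = meta["original_length"] / meta["sample_rate"]

print(codes.shape, meta.get("text"), meta.get("language"), duration)
```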

View File

@ -17,7 +17,56 @@ try:
if is_flash_attn_2_available():
AVAILABLE_ATTENTIONS.append("flash_attention_2")
except Exception as e:
print("Error while querying for `flash_attn_2` support", e)
print("Error while querying for `flash_attention_2` support", e)
# Borrowed from https://github.com/turboderp/exllamav2/blob/master/exllamav2/attn.py#L32
# Adapted to provide flash_attn_v1 support
try:
import flash_attn
flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
if [1, 0, 9] == flash_attn_ver:
AVAILABLE_ATTENTIONS.append("flash_attn")
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from einops import rearrange
# converts the flash_attn_2 calling convention to flash_attn_1's
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False, deterministic=False, *args, **kwargs):
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = k.shape[1]
q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]]
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
cu_seqlens_k = cu_seqlens_q
return flash_attn_unpadded_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
dropout_p, softmax_scale, causal, return_attn_probs, deterministic
)
has_flash_attn = True
elif [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
AVAILABLE_ATTENTIONS.append("flash_attn")
from flash_attn import flash_attn_func
has_flash_attn = True
elif [2, 5, 7] <= flash_attn_ver:
AVAILABLE_ATTENTIONS.append("flash_attn")
from flash_attn import flash_attn_func, flash_attn_with_kvcache
signature = list(inspect.signature(flash_attn_func).parameters)
has_flash_attn_with_window = "window_size" in signature
has_flash_attn_with_softcap = "softcap" in signature
import flash_attn_2_cuda as flash_attn_cuda
has_flash_attn = True
has_flash_attn_with_paged = True
except Exception as e:
print("Error while querying for `flash_attn` | support", e)
"""
try:
@ -128,16 +177,25 @@ class LlamaAttention_Adapted(LlamaAttention):
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
#with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
with torch.nn.attention.sdpa_kernel(self.mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
if self.mode == "flash_attn":
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
causal=True,
softmax_scale=None, # 1 / math.sqrt(cfg.head_dim)
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
else:
with torch.nn.attention.sdpa_kernel(self.mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
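Whichever path ends up binding `flash_attn_func` (the 1.0.9 wrapper or the native 2.x function), the callable exposes the same interface; a standalone usage sketch with assumed toy shapes:

```python
# Sketch: exercising flash_attn_func directly. flash_attn expects
# (batch, seqlen, heads, head_dim) tensors in fp16/bf16 on a CUDA device.
import torch

try:
    from flash_attn import flash_attn_func  # 2.x; the 1.0.9 wrapper exposes the same signature
except ImportError:
    flash_attn_func = None

if flash_attn_func is not None and torch.cuda.is_available():
    b, s, h, d = 2, 128, 8, 64  # assumed toy shapes
    q, k, v = (torch.randn(b, s, h, d, device="cuda", dtype=torch.float16) for _ in range(3))
    out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
    print(out.shape)  # (b, s, h, d)
```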

View File

@ -8,6 +8,11 @@ from transformers.cache_utils import Cache
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock, MixtralAttention, apply_rotary_pos_emb, repeat_kv
try:
from .llama import flash_attn_func
except Exception as e:
pass
# This is required because batch sizes > 1 throws errors
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
@ -139,15 +144,25 @@ class MixtralAttention_Adapted(MixtralAttention):
is_causal = True if causal_mask is None and q_len > 1 else False
#with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
with torch.nn.attention.sdpa_kernel(self.mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
if self.mode == "flash_attn":
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
causal=True,
softmax_scale=None, # 1 / math.sqrt(cfg.head_dim)
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
else:
with torch.nn.attention.sdpa_kernel(self.mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
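The non-`flash_attn` branch in both adapted classes leans on `torch.nn.attention.sdpa_kernel` to pin SDPA to one backend (assuming `self.mode` resolves to an `SDPBackend` value); a small standalone sketch of that context manager in recent PyTorch:

```python
# Sketch: forcing a particular scaled_dot_product_attention backend,
# which is what the sdpa_kernel(self.mode) call above is doing.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

b, h, s, d = 2, 8, 128, 64  # assumed toy shapes; SDPA takes (batch, heads, seqlen, head_dim)
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

with sdpa_kernel(SDPBackend.MATH):  # or SDPBackend.FLASH_ATTENTION / EFFICIENT_ATTENTION, etc.
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (b, h, s, d)
```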

View File

@ -524,7 +524,7 @@ class Base(nn.Module):
hf_attention = attention_backend
if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn"]:
if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn", "flash_attn"]:
hf_attention = None
if attention_backend not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
@ -576,7 +576,7 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]:
self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
@ -601,7 +601,7 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]:
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
else:
self.model = MixtralModel(MixtralConfig(
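`ml.replace_attention` itself isn't part of this diff; purely as a hypothetical sketch of what swapping stock attention modules for an adapted class generally involves (the constructor arguments, the `mode` attribute, and the helper's name are assumptions, not the project's actual implementation):

```python
# Hypothetical sketch of attention-module replacement; ml.replace_attention is
# not shown in this commit, so the details below are assumptions.
import torch

def replace_attention(model, klass, target, mode):
    """Swap every `target` attention module for an adapted `klass` instance, keeping weights."""
    for module in model.modules():
        for child_name, child in list(module.named_children()):
            if isinstance(child, target) and not isinstance(child, klass):
                adapted = klass(child.config, getattr(child, "layer_idx", None))
                adapted.load_state_dict(child.state_dict(), strict=False)  # reuse original weights
                adapted.mode = mode  # assumption: the adapted class reads its backend from `mode`
                setattr(module, child_name, adapted)
    return model
```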