added flash_attn LlamaAttention (including flash_attn==1.0.9)
parent 054d28573a
commit d636edd3a2
@@ -159,6 +159,8 @@ For audio backends:

* `sdpa`: integrated `LlamaSdpaAttention` attention model
* `flash_attention_2`: integrated `LlamaFlashAttention2` attention model
* `auto`: determine the best fit from the above
* `eager`: default `LlamaAttention`
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)

The wide support for various backends is solely while I try and figure out which is the "best" for a core foundation model.
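For reference, a minimal sketch (not part of this commit) of how a requested backend name from the list above might be resolved against whatever was detected at import time; `AVAILABLE_ATTENTIONS` mirrors the list populated in the hunks further down, and the `"auto"` preference order here is an assumption for illustration, not the repo's actual logic:

```python
# Hypothetical helper; the "auto" fallback order is an assumption.
def pick_attention_backend(requested: str, available: list[str]) -> str:
    if requested == "auto":
        # prefer fused kernels when present, fall back to plain eager attention
        for candidate in ("flash_attention_2", "flash_attn", "sdpa", "eager"):
            if candidate == "eager" or candidate in available:
                return candidate
    if requested != "eager" and requested not in available:
        raise ValueError(f"attention backend {requested!r} not available: {available}")
    return requested
```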
@@ -1332,24 +1332,28 @@ def create_dataset_metadata( skip_existing=True ):

        if not os.path.isdir(f'{root}/{name}/'):
            return

        # tqdm.write(f'{root}/{name}')
        files = os.listdir(f'{root}/{name}/')

        # grab IDs for every file
        ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }

+       wrote = False
+
        for id in tqdm(ids, desc=f"Processing {name}"):
            try:
-               quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
                text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
+               quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')

-               if not quant_exists:
+               if audios and not quant_path.exists():
                    continue

                key = f'{type}/{speaker_name}/{id}'

                if skip_existing and id in metadata:
                    continue

+               wrote = True
+
                if id not in metadata:
                    metadata[id] = {}
@@ -1357,7 +1361,7 @@ def create_dataset_metadata( skip_existing=True ):
                utterance_metadata = {}
                if audios:
                    # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
-                   dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
+                   dac = np.load(quant_path, allow_pickle=True)[()]
                    qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)

                    if "text" in dac["metadata"]:
@@ -1368,9 +1372,6 @@ def create_dataset_metadata( skip_existing=True ):
                        utterance_metadata["language"] = dac["metadata"]["language"]
                    if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
                        utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
-               # text
-               if texts and text_exists and not utterance_metadata:
-                   utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())

                for k, v in utterance_metadata.items():
                    metadata[id][k] = v
@@ -1378,8 +1379,9 @@ def create_dataset_metadata( skip_existing=True ):
            except Exception as e:
                tqdm.write(f'Error while processing {id}: {e}')

-       with open(str(metadata_path), "w", encoding="utf-8") as f:
-           f.write( json.dumps( metadata ) )
+       if wrote:
+           with open(str(metadata_path), "w", encoding="utf-8") as f:
+               f.write( json.dumps( metadata ) )

    # training
    for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
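For orientation, the per-dataset metadata file written above ends up keyed by utterance ID; the sketch below only shows the fields visible in these hunks (`language`, `duration`, and optionally `text`), and the ID and values are made up:

```python
# Illustrative shape of the written metadata.json (hypothetical values):
example_metadata = {
    "utterance_0001": {
        "text": "...",       # copied from dac["metadata"] when present
        "language": "en",
        "duration": 3.21,    # original_length / sample_rate
    },
}
```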
@@ -17,7 +17,56 @@ try:
    if is_flash_attn_2_available():
        AVAILABLE_ATTENTIONS.append("flash_attention_2")
except Exception as e:
-   print("Error while querying for `flash_attn_2` support", e)
+   print("Error while querying for `flash_attention_2` support", e)

+# Borrowed from https://github.com/turboderp/exllamav2/blob/master/exllamav2/attn.py#L32
+# Adapted to provide flash_attn_v1 support
+try:
+   import flash_attn
+   flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
+   is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
+
+   if [1, 0, 9] == flash_attn_ver:
+       AVAILABLE_ATTENTIONS.append("flash_attn")
+       from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+       from einops import rearrange
+
+       # converts the flash_attn_2 calling convention to flash_attn_1's
+       def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False, deterministic=False, *args, **kwargs):
+           batch_size, seqlen_q = q.shape[0], q.shape[1]
+           seqlen_k = k.shape[1]
+           q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]]
+
+           cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
+           cu_seqlens_k = cu_seqlens_q
+
+           return flash_attn_unpadded_func(
+               q, k, v,
+               cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+               dropout_p, softmax_scale, causal, return_attn_probs, deterministic
+           )
+
+       has_flash_attn = True
+   elif [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
+       AVAILABLE_ATTENTIONS.append("flash_attn")
+       from flash_attn import flash_attn_func
+       has_flash_attn = True
+   elif [2, 5, 7] <= flash_attn_ver:
+       AVAILABLE_ATTENTIONS.append("flash_attn")
+       from flash_attn import flash_attn_func, flash_attn_with_kvcache
+
+       signature = list(inspect.signature(flash_attn_func).parameters)
+       has_flash_attn_with_window = "window_size" in signature
+       has_flash_attn_with_softcap = "softcap" in signature
+
+       import flash_attn_2_cuda as flash_attn_cuda
+
+       has_flash_attn = True
+       has_flash_attn_with_paged = True
+except Exception as e:
+   print("Error while querying for `flash_attn` support", e)

"""
try:
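The v1 shim above only has to make `flash_attn==1.0.9`'s unpadded interface answer to the `flash_attn_func` name used elsewhere. A rough smoke test like the following (my own sketch, assuming a CUDA device and fp16 inputs, which `flash_attn` requires) exercises the (batch, seqlen, heads, head_dim) calling convention:

```python
import torch

# Hypothetical smoke test for whichever flash_attn_func the block above bound;
# flash_attn kernels expect fp16/bf16 CUDA tensors laid out (batch, seqlen, heads, head_dim).
def smoke_test_flash_attn(flash_attn_func, batch=2, seqlen=128, heads=8, head_dim=64):
    q, k, v = (
        torch.randn(batch, seqlen, heads, head_dim, dtype=torch.float16, device="cuda")
        for _ in range(3)
    )
    # causal self-attention, no dropout; flash_attn v2 returns (batch, seqlen, heads, head_dim),
    # while the v1 shim above, as written, returns the flattened (batch * seqlen, heads, head_dim).
    out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
    return out.shape
```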
@@ -128,16 +177,25 @@ class LlamaAttention_Adapted(LlamaAttention):
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

-       #with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
-       with torch.nn.attention.sdpa_kernel(self.mode):
-           attn_output = torch.nn.functional.scaled_dot_product_attention(
-               query_states,
-               key_states,
-               value_states,
-               attn_mask=causal_mask,
-               dropout_p=self.attention_dropout if self.training else 0.0,
-               is_causal=is_causal,
-           )
+       if self.mode == "flash_attn":
+           attn_output = flash_attn_func(
+               query_states,
+               key_states,
+               value_states,
+               causal=True,
+               softmax_scale=None, # 1 / math.sqrt(cfg.head_dim),
+               dropout_p=self.attention_dropout if self.training else 0.0,
+           )
+       else:
+           with torch.nn.attention.sdpa_kernel(self.mode):
+               attn_output = torch.nn.functional.scaled_dot_product_attention(
+                   query_states,
+                   key_states,
+                   value_states,
+                   attn_mask=causal_mask,
+                   dropout_p=self.attention_dropout if self.training else 0.0,
+                   is_causal=is_causal,
+               )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
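One practical difference between the two branches above is tensor layout: `torch.nn.functional.scaled_dot_product_attention` consumes `(bsz, num_heads, q_len, head_dim)` (hence the `transpose(1, 2)` afterwards), while `flash_attn_func` expects `(bsz, q_len, num_heads, head_dim)`. A minimal sketch of the conversion, assuming HF-style states as inputs; this is an illustration of the two calling conventions, not code from this commit:

```python
# Hypothetical adapter: run flash_attn_func on HF-layout states and hand back
# the (bsz, num_heads, q_len, head_dim) layout the surrounding code expects.
def flash_attn_on_hf_states(flash_attn_func, query_states, key_states, value_states, dropout_p=0.0):
    q, k, v = (t.transpose(1, 2) for t in (query_states, key_states, value_states))
    out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)  # (bsz, q_len, heads, head_dim)
    return out.transpose(1, 2)  # back to (bsz, heads, q_len, head_dim)
```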
@@ -8,6 +8,11 @@ from transformers.cache_utils import Cache
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock, MixtralAttention, apply_rotary_pos_emb, repeat_kv

+try:
+   from .llama import flash_attn_func
+except Exception as e:
+   pass
+
# This is required because batch sizes > 1 throws errors
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """ """
@@ -139,15 +144,25 @@ class MixtralAttention_Adapted(MixtralAttention):
        is_causal = True if causal_mask is None and q_len > 1 else False

-       #with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
-       with torch.nn.attention.sdpa_kernel(self.mode):
-           attn_output = torch.nn.functional.scaled_dot_product_attention(
-               query_states,
-               key_states,
-               value_states,
-               attn_mask=causal_mask,
-               dropout_p=self.attention_dropout if self.training else 0.0,
-               is_causal=is_causal,
-           )
+       if self.mode == "flash_attn":
+           attn_output = flash_attn_func(
+               query_states,
+               key_states,
+               value_states,
+               causal=True,
+               softmax_scale=None, # 1 / math.sqrt(cfg.head_dim),
+               dropout_p=self.attention_dropout if self.training else 0.0,
+           )
+       else:
+           with torch.nn.attention.sdpa_kernel(self.mode):
+               attn_output = torch.nn.functional.scaled_dot_product_attention(
+                   query_states,
+                   key_states,
+                   value_states,
+                   attn_mask=causal_mask,
+                   dropout_p=self.attention_dropout if self.training else 0.0,
+                   is_causal=is_causal,
+               )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -524,7 +524,7 @@ class Base(nn.Module):

        hf_attention = attention_backend

-       if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn"]:
+       if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn", "flash_attn"]:
            hf_attention = None
            if attention_backend not in AVAILABLE_ATTENTIONS:
                raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
@@ -576,7 +576,7 @@ class Base(nn.Module):
                attn_implementation=hf_attention,
                #gradient_checkpointing=self.gradient_checkpointing,
            ))
-           if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+           if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]:
                self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )

            if self.gradient_checkpointing and not self.model.gradient_checkpointing:
@@ -601,7 +601,7 @@ class Base(nn.Module):
                attn_implementation=hf_attention,
                #gradient_checkpointing=self.gradient_checkpointing,
            ))
-           if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+           if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]:
                self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
        else:
            self.model = MixtralModel(MixtralConfig(
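`ml.replace_attention` itself is not shown in this diff; conceptually it swaps every stock attention module for the adapted class and records the requested mode on it so the adapted `forward()` can branch on `"flash_attn"` as above. A rough sketch of that idea follows; this is my own hypothetical helper under those assumptions, not the repo's implementation, and it leaves out how the SDPA mode names get mapped for `torch.nn.attention.sdpa_kernel`:

```python
import torch

# Hypothetical replace_attention-style helper (not the repo's ml.replace_attention):
# walk the model, swap each `target` attention module for `klass`, copy its weights,
# and stash the requested mode for the adapted forward() to branch on.
def replace_attention_sketch(model: torch.nn.Module, klass, target, mode):
    for parent in model.modules():
        for child_name, child in parent.named_children():
            if isinstance(child, target) and not isinstance(child, klass):
                # assumes the adapted class shares the (config, layer_idx) constructor
                adapted = klass(child.config, getattr(child, "layer_idx", None))
                adapted.load_state_dict(child.state_dict(), strict=False)
                adapted.mode = mode  # "flash_attn" stays a plain string
                setattr(parent, child_name, adapted)
    return model
```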