From d636edd3a2e92ba32cdfe07802b1ab9484b7f8e4 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 18 Aug 2024 20:51:14 -0500 Subject: [PATCH] added flash_attn LlamaAttention (including flash_attn==1.0.9) --- README.md | 2 + vall_e/data.py | 20 +++++----- vall_e/models/arch/llama.py | 70 ++++++++++++++++++++++++++++++++--- vall_e/models/arch/mixtral.py | 23 ++++++++++-- vall_e/models/base.py | 6 +-- 5 files changed, 99 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 15ffdd7..9bb6381 100755 --- a/README.md +++ b/README.md @@ -159,6 +159,8 @@ For audio backends: * `sdpa`: integrated `LlamaSdpaAttention` attention model * `flash_attention_2`: integrated `LlamaFlashAttetion2` attention model * `auto`: determine the best fit from the above +* `eager`: default `LlamaAttention` +* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper) The wide support for various backends is solely while I try and figure out which is the "best" for a core foundation model. diff --git a/vall_e/data.py b/vall_e/data.py index c01d01a..ed9ae76 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -1332,24 +1332,28 @@ def create_dataset_metadata( skip_existing=True ): if not os.path.isdir(f'{root}/{name}/'): return + # tqdm.write(f'{root}/{name}') files = os.listdir(f'{root}/{name}/') # grab IDs for every file ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files } + wrote = False + for id in tqdm(ids, desc=f"Processing {name}"): try: - quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True - text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True + quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}') - if not quant_exists: + if audios and not quant_path.exists(): continue key = f'{type}/{speaker_name}/{id}' if skip_existing and id in metadata: continue + + wrote = True if id not in metadata: metadata[id] = {} @@ -1357,7 +1361,7 @@ def create_dataset_metadata( skip_existing=True ): utterance_metadata = {} if audios: # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt - dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()] + dac = np.load(quant_path, allow_pickle=True)[()] qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16) if "text" in dac["metadata"]: @@ -1368,9 +1372,6 @@ def create_dataset_metadata( skip_existing=True ): utterance_metadata["language"] = dac["metadata"]["language"] if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]: utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"] - # text - if texts and text_exists and not utterance_metadata: - utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) for k, v in utterance_metadata.items(): metadata[id][k] = v @@ -1378,8 +1379,9 @@ def create_dataset_metadata( skip_existing=True ): except Exception as e: tqdm.write(f'Error while processing {id}: {e}') - with open(str(metadata_path), "w", encoding="utf-8") as f: - f.write( json.dumps( metadata ) ) + if wrote: + with open(str(metadata_path), "w", encoding="utf-8") as f: + f.write( json.dumps( metadata ) ) # training for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index b92e145..1088220 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -17,7 +17,56 @@ try: if is_flash_attn_2_available(): AVAILABLE_ATTENTIONS.append("flash_attention_2") except Exception as e: - print("Error while querying for `flash_attn_2` support", e) + print("Error while querying for `flash_attention_2` support", e) + +# Borrowed from https://github.com/turboderp/exllamav2/blob/master/exllamav2/attn.py#L32 +# Adapted to provide flash_attn_v1 support +try: + import flash_attn + flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()] + is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count())) + + if [1, 0, 9] == flash_attn_ver: + AVAILABLE_ATTENTIONS.append("flash_attn") + from flash_attn.flash_attn_interface import flash_attn_unpadded_func + from einops import rearrange + + # converts the flash_attn_2 calling convention to flash_attn_1's + def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False, deterministic=False, *args, **kwargs): + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]] + + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device) + cu_seqlens_k = cu_seqlens_q + + return flash_attn_unpadded_func( + q, k, v, + cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, + dropout_p, softmax_scale, causal, return_attn_probs, deterministic + ) + + has_flash_attn = True + elif [2, 2, 1] <= flash_attn_ver < [2, 5, 7]: + AVAILABLE_ATTENTIONS.append("flash_attn") + from flash_attn import flash_attn_func + has_flash_attn = True + elif [2, 5, 7] <= flash_attn_ver: + AVAILABLE_ATTENTIONS.append("flash_attn") + from flash_attn import flash_attn_func, flash_attn_with_kvcache + + signature = list(inspect.signature(flash_attn_func).parameters) + has_flash_attn_with_window = "window_size" in signature + has_flash_attn_with_softcap = "softcap" in signature + + import flash_attn_2_cuda as flash_attn_cuda + + has_flash_attn = True + has_flash_attn_with_paged = True + + +except Exception as e: + print("Error while querying for `flash_attn` | support", e) """ try: @@ -128,16 +177,25 @@ class LlamaAttention_Adapted(LlamaAttention): # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False - #with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"): - with torch.nn.attention.sdpa_kernel(self.mode): - attn_output = torch.nn.functional.scaled_dot_product_attention( + if self.mode == "flash_attn": + attn_output = flash_attn_func( query_states, key_states, value_states, - attn_mask=causal_mask, + causal=True, + softmax_scale=None, # 1, / math.sqrt(cfg.head_dim), dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, ) + else: + with torch.nn.attention.sdpa_kernel(self.mode): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) diff --git a/vall_e/models/arch/mixtral.py b/vall_e/models/arch/mixtral.py index e4f6c12..d9b486d 100644 --- a/vall_e/models/arch/mixtral.py +++ b/vall_e/models/arch/mixtral.py @@ -8,6 +8,11 @@ from transformers.cache_utils import Cache from transformers import MixtralModel, MixtralConfig from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock, MixtralAttention, apply_rotary_pos_emb, repeat_kv +try: + from .llama import flash_attn_func +except Exception as e: + pass + # This is required because batch sizes > 1 throws errors def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ @@ -139,15 +144,25 @@ class MixtralAttention_Adapted(MixtralAttention): is_causal = True if causal_mask is None and q_len > 1 else False #with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"): - with torch.nn.attention.sdpa_kernel(self.mode): - attn_output = torch.nn.functional.scaled_dot_product_attention( + if self.mode == "flash_attn": + attn_output = flash_attn_func( query_states, key_states, value_states, - attn_mask=causal_mask, + causal=True, + softmax_scale=None, # 1, / math.sqrt(cfg.head_dim), dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, ) + else: + with torch.nn.attention.sdpa_kernel(self.mode): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 8987e58..cf5a202 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -524,7 +524,7 @@ class Base(nn.Module): hf_attention = attention_backend - if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn"]: + if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn", "flash_attn"]: hf_attention = None if attention_backend not in AVAILABLE_ATTENTIONS: raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}") @@ -576,7 +576,7 @@ class Base(nn.Module): attn_implementation=hf_attention, #gradient_checkpointing=self.gradient_checkpointing, )) - if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]: + if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]: self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend ) if self.gradient_checkpointing and not self.model.gradient_checkpointing: @@ -601,7 +601,7 @@ class Base(nn.Module): attn_implementation=hf_attention, #gradient_checkpointing=self.gradient_checkpointing, )) - if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]: + if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]: self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend ) else: self.model = MixtralModel(MixtralConfig(