diff --git a/README.md b/README.md index 9bb6381..87fad22 100755 --- a/README.md +++ b/README.md @@ -147,23 +147,43 @@ For audio backends: * [`encodec`](https://github.com/facebookresearch/encodec): a tried-and-tested EnCodec to encode/decode audio. * [`vocos`](https://huggingface.co/charactr/vocos-encodec-24khz): a higher quality EnCodec decoder. - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos` -* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality +* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality, but has issues with model convergence. - models at 24KHz + 8kbps will NOT converge in any manner. - models at 44KHz + 8kbps seems harder to model its "language", and the NAR side of the model suffers greatly. `llama`-based models also support different attention backends: -* `math`: torch's SDPA's `math` implementation -* `mem_efficient`: torch's SDPA's memory efficient (`xformers` adjacent) implementation -* `flash`: torch's SDPA's flash attention implementation -* `xformers`: ~~[facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention~~ Aliased to `mem_efficient` -* `sdpa`: integrated `LlamaSdpaAttention` attention model -* `flash_attention_2`: integrated `LlamaFlashAttetion2` attention model +* `torch.nn.functional.scaled_dot_product_attention`-based attention: + * `math`: torch's SDPA's `math` kernel + * `mem_efficient`: torch's SDPA's memory efficient (`xformers` adjacent) kernel + * `cudnn`: torch's SDPA's `cudnn` kernel + * `flash`: torch's SDPA's flash attention kernel +* internal implementations of external attention backends: + * `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention + * `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper) + * `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently) + * `fused_attn`: uses an implementation using `triton` (only tested on my 7900XTX / Navi3 / gfx1100) +* `transformers` Llama\*Attention implementations: + * `eager`: default `LlamaAttention` + * `sdpa`: integrated `LlamaSdpaAttention` attention model + * `flash_attention_2`: integrated `LlamaFlashAttetion2` attention model * `auto`: determine the best fit from the above -* `eager`: default `LlamaAttention` -* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper) The wide support for various backends is solely while I try and figure out which is the "best" for a core foundation model. +##### ROCm Flash Attention + +[ROCm/flash-attention](https://github.com/ROCm/flash-attention) currently does not support Navi3 cards (gfx11xx), so first-class support for Flash Attention is a bit of a mess on Navi3. Using the `howiejay/navi_support` branch can get inference support, but not training support (due to some error being thrown during the backwards pass) by: +* edit `/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h`: +``` + #if defined(__HIPCC_RTC__) + #define __HOST_DEVICE__ __device__ static + #else + #include + #define __HOST_DEVICE__ __host__ __device__ static inline + #endif +``` +* install with `pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-build-isolation` + ## Export To export the models, run: `python -m vall_e.export --yaml=./training/config.yaml`. diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py index 3d37f5b..187da69 100755 --- a/vall_e/models/arch/__init__.py +++ b/vall_e/models/arch/__init__.py @@ -34,7 +34,7 @@ try: AVAILABLE_ARCHES.append("llama") except Exception as e: ERROR_ARCHES["llama"] = e - AVAILABLE_ARCHES = [] + AVAILABLE_ATTENTIONS = [] pass try: diff --git a/vall_e/models/arch/attention/fused.py b/vall_e/models/arch/attention/fused.py new file mode 100644 index 0000000..633cf3e --- /dev/null +++ b/vall_e/models/arch/attention/fused.py @@ -0,0 +1,717 @@ +# Grabbed and tweaked from https://github.com/ardfork/ComfyUI-flash-attention-triton/blob/master/fused_attention.py +# There's a bunch of other fused attentions out there and each one has its own problems it seems + +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch + +import triton +import triton.language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, # + K_block_ptr, + V_block_ptr, # + start_m, + qk_scale, # + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, # + N_CTX: tl.constexpr, + fp8_v: tl.constexpr, +): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + """ + if fp8_v: + p = p.to(tl.float8e5) + else: + p = p.to(tl.float16) + """ + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +# We don't run auto-tuning every time to keep the tutorial fast. Keeping +# the code below and commenting out the equivalent parameters is convenient for +# re-tuning. +configs = [ + triton.Config({"BLOCK_M": BM, "BLOCK_N": BN}, num_stages=s, num_warps=w) + for BM in [64, 128] + for BN in [32, 64] + for s in ([1] if is_hip() else [3, 4, 7]) + for w in [4, 8] +] + + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + + +@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) +@triton.jit +def _attn_fwd( + Q, + K, + V, + sm_scale, + M, + Out, # + stride_qz, + stride_qh, + stride_qm, + stride_qk, # + stride_kz, + stride_kh, + stride_kn, + stride_kk, # + stride_vz, + stride_vh, + stride_vk, + stride_vn, # + stride_oz, + stride_oh, + stride_om, + stride_on, # + Z, + H, + N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, # +): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + # block pointers + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(HEAD_DIM, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, # + start_m, + qk_scale, # + BLOCK_M, + HEAD_DIM, + BLOCK_N, # + 4 - STAGE, + offs_m, + offs_n, + N_CTX, + V.dtype.element_ty == tl.float8e5, # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, # + start_m, + qk_scale, # + BLOCK_M, + HEAD_DIM, + BLOCK_N, # + 2, + offs_m, + offs_n, + N_CTX, + V.dtype.element_ty == tl.float8e5, # + ) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_bwd_preprocess( + O, + DO, # + Delta, # + Z, + H, + N_CTX, # + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, # +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) + off_n = tl.arange(0, HEAD_DIM) + # load + o = tl.load( + O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :] + ) + do = tl.load( + DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :] + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hz * N_CTX + off_m, delta) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv( + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + # shared by Q/K/V/DO. + stride_tok, + stride_d, # + H, + N_CTX, + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + # Filled in by the wrapper. + start_n, + start_m, + num_steps, # + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = offs_m[None, :] >= offs_n[:, None] + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(do.dtype) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(qT.dtype) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq( + dq, + q, + K, + V, # + do, + m, + D, + # shared by Q/K/V/DO. + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_m, + start_n, + num_steps, # + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = offs_m[:, None] >= offs_n[None, :] + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(kT.dtype) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _attn_bwd( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr, +): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv( + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + MASK_BLOCK_M1, + BLOCK_N1, + HEAD_DIM, # + start_n, + start_m, + num_steps, # + MASK=True, # + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv( # + dk, + dv, # + Q, + k, + v, + sm_scale, # + DO, # + M, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M1, + BLOCK_N1, + HEAD_DIM, # + start_n, + start_m, + num_steps, # + MASK=False, # + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq( + dq, + q, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2, + MASK_BLOCK_N2, + HEAD_DIM, # + start_m, + end_n - num_steps * MASK_BLOCK_N2, + num_steps, # + MASK=True, # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq( + dq, + q, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + N_CTX, # + BLOCK_M2, + BLOCK_N2, + HEAD_DIM, # + start_m, + end_n - num_steps * BLOCK_N2, + num_steps, # + MASK=False, # + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + # Tuning for AMD target + if is_hip(): + waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 + extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} + + grid = lambda args: ( + triton.cdiv(q.shape[2], args["BLOCK_M"]), + q.shape[0] * q.shape[1], + 1, + ) + M = torch.empty( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + _attn_fwd[grid]( + q, + k, + v, + sm_scale, + M, + o, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), # + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), # + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), # + q.shape[0], + q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + **extra_kern_args, + ) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, M = ctx.saved_tensors + #assert do.is_contiguous() + #assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o, + do, # + delta, # + BATCH, + N_HEAD, + N_CTX, # + BLOCK_M=PRE_BLOCK, + HEAD_DIM=ctx.HEAD_DIM, # + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, + arg_k, + v, + ctx.sm_scale, + do, + dq, + dk, + dv, # + M, + delta, # + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), # + N_HEAD, + N_CTX, # + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES, # + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply \ No newline at end of file diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index f8712d1..f486e65 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -20,40 +20,25 @@ try: except Exception as e: print("Error while querying for `flash_attention_2` support", e) +try: + from .attention.fused import attention as _fused_attention + def fused_attn_func(q, k, v, softmax_scale=None, causal=False, *args, **kwargs): + return _fused_attention( q, k, v, causal, softmax_scale ) + + AVAILABLE_ATTENTIONS.append("fused_attn") +except Exception as e: + print("Error while querying for `fused_attn` support", e) + + is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count())) is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count())) try: - if is_rocm and False: - # try to use triton flash attention / fused attention - # currently only forward works, backwards throws an assert - # even then it's extremely slow on my 7900XTX so the provided code is probably botched since it's a benchmark sample - from einops import rearrange - from .triton_flash_attention import triton_attention, MetaData - - def flash_attn_func(q, k, v, softmax_scale=None, causal=False, *args, **kwargs): - metadata = MetaData(sm_scale=softmax_scale) - batch_size, seqlen_q, seqlen_k = q.shape[0], q.shape[1], k.shape[1] - - metadata.max_seqlens_q = seqlen_q - metadata.max_seqlens_k = seqlen_k - - # varlen but doesn't seem necessary - if False: - q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]] - - cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device) - cu_seqlens_k = cu_seqlens_q - - metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) - - if causal: - metadata.need_causal() - - return triton_attention( q, k, v, None, metadata )[0] - + if is_rocm: + # requires pain to set up on Navi3, and for no backwards (training) support + from flash_attn import flash_attn_func AVAILABLE_ATTENTIONS.append("flash_attn") - AVAILABLE_ATTENTIONS.append("flash_attn_rocm") + elif not is_ampere_or_newer_gpu: # Uses https://github.com/ZRayZzz/flash-attention-v100/ # Currently doesn't work because it's hard-coded to use a head dim of 128, will throw NaNs otherwise... @@ -113,6 +98,7 @@ try: has_flash_attn = True has_flash_attn_with_paged = True except Exception as e: + raise e print("Error while querying for `flash_attn` support", e) try: @@ -278,15 +264,25 @@ class LlamaAttention_Adapted(LlamaAttention): # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False - with torch.nn.attention.sdpa_kernel(self.mode): - attn_output = torch.nn.functional.scaled_dot_product_attention( + if self.mode in ["fused_attn"]: + attn_output = fused_attn_func( query_states, key_states, value_states, - attn_mask=causal_mask, + causal=True, + softmax_scale=1.0 / math.sqrt(self.head_dim), dropout_p=dropout_rate, - is_causal=is_causal, ) + else: + with torch.nn.attention.sdpa_kernel(self.mode): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=dropout_rate, + is_causal=is_causal, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index ed96d54..e4fafde 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -401,9 +401,6 @@ class Base(nn.Module): self.l_padding = l_padding - if "flash_attn_v100" in AVAILABLE_ATTENTIONS: - self.l_padding = 32 - self.ignore_index = -100 self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels @@ -521,12 +518,19 @@ class Base(nn.Module): attention_backend = "eager" hf_attention = attention_backend + HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"] - if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn", "flash_attn"]: + if attention_backend not in HF_ATTENTIONS: hf_attention = None if attention_backend not in AVAILABLE_ATTENTIONS: raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}") + if attention_backend == "flash_attn_v100": + self.l_padding = 32 + + if attention_backend == "fused_attn": + self.l_padding = 128 + if self.arch_type == "transformer": self.sin_emb = SinusoidalEmbedding(d_model) self.blocks = nn.ModuleList([TransformerBlock( @@ -574,7 +578,7 @@ class Base(nn.Module): attn_implementation=hf_attention, #gradient_checkpointing=self.gradient_checkpointing, )) - if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]: + if attention_backend not in HF_ATTENTIONS: self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend ) if self.gradient_checkpointing and not self.model.gradient_checkpointing: @@ -599,7 +603,7 @@ class Base(nn.Module): attn_implementation=hf_attention, #gradient_checkpointing=self.gradient_checkpointing, )) - if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]: + if attention_backend not in HF_ATTENTIONS: self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend ) else: self.model = MixtralModel(MixtralConfig( @@ -621,7 +625,7 @@ class Base(nn.Module): attn_implementation=hf_attention, #gradient_checkpointing=self.gradient_checkpointing, )) - if attention_backend in ["xformers", "mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]: + if attention_backend not in HF_ATTENTIONS: self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend ) if self.gradient_checkpointing and not self.model.gradient_checkpointing: