author    | Kawrakow <iwankawrakow@gmail.com> | 2025-03-07 09:46:58 +0200
committer | GitHub <noreply@github.com>       | 2025-03-07 09:46:58 +0200
commit    | 3d85a1d66302989401f92a5ae347577b03cbdaa7 (patch)
tree      | 7d9ea15568de65954ebddbf71792ad781841fd7f
parent    | c67a37b251fc22b0f8b8313ea5c76a73ff6ed49f (diff)
Better FlashMLA (#243)
* This is a better FA for TG
It should benefit MLA and GQA. Tested to work with
DeepSeek-Lite MLA; not yet tested for GQA.
For tg64@pp8192 it is ~13% faster than MLA without FA,
and 57% faster than the main branch FA.
* WIP
* Cleanup
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
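The core of the change is a split-K scheme for token generation: each thread computes attention over one slice of the K/V cache for a group of query heads, storing an unnormalized accumulator together with the per-row softmax maximum M and sum S, and the partial results are then merged with the usual log-sum-exp rescaling. A minimal standalone sketch of that merge step (simplified from the combine loop in iqk_flash_attn.cpp below; not the exact code):

```cpp
#include <cmath>
#include <cstring>

// Merge one slice's partial result (R, Mj, Sj) into the running
// accumulator (Racc, M, S) for a single output row of size Dv.
// The caller initializes M = -INFINITY, S = 0 before the first merge.
static void merge_partial(float * Racc, float & M, float & S,
                          const float * R, float Mj, float Sj, int Dv) {
    if (Mj == -INFINITY) return;                 // fully masked slice, no contribution
    if (Mj > M) {
        if (M == -INFINITY) {                    // first contributing slice
            std::memcpy(Racc, R, Dv*sizeof(float));
            S = Sj;
        } else {
            float c = std::exp(M - Mj);          // rescale what was accumulated so far
            for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
            S = c*S + Sj;
        }
        M = Mj;
    } else {
        float c = std::exp(Mj - M);              // rescale the incoming slice
        for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
        S += c*Sj;
    }
}

// After all slices are merged, normalize: out = Racc / S
// (the guard mirrors the diff's handling of fully masked rows).
static void finalize(float * Racc, float S, int Dv) {
    float norm = S > 0 ? 1/S : 1;
    for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
}
```

The merge is order-independent, so slices can be combined in any order; in the diff a single thread per output row does it after a barrier.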
-rw-r--r-- | ggml/src/CMakeLists.txt         |   4
-rw-r--r-- | ggml/src/ggml.c                 | 112
-rw-r--r-- | ggml/src/iqk/iqk_common.h       | 138
-rw-r--r-- | ggml/src/iqk/iqk_flash_attn.cpp | 194
-rw-r--r-- | ggml/src/iqk/iqk_flash_impl.h   |  23
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp    | 274
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.h      |  16
7 files changed, 582 insertions, 179 deletions
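Most of the new logic lives in iqk_flash_attn.cpp: for MLA token generation (a single q column and a single K/V head) the nth threads are split into gcd_k slices along the K cache and nth_k = nth/gcd_k groups over the rk2 query heads, with gcd_k = gcd(nek1/32, nth). A sketch of that decomposition (split_threads is a hypothetical helper; names otherwise follow the diff):

```cpp
#include <cstdint>

static uint32_t simple_gcd(uint32_t a, uint32_t b) { // subtraction-based gcd, as in the diff
    while (a != b) { if (a > b) a -= b; else b -= a; }
    return a;
}

// Assign thread ith of nth a K-cache slice and a group of query heads.
// nek1 = rows in the K cache (a multiple of 32), rk2 = query heads per K/V head.
// The path is only taken when rk2 % nth_k == 0, so every thread gets equal work.
void split_threads(int ith, int nth, int nek1, int rk2,
                   int & k_first, int & k_count, int & q_first, int & q_count) {
    int nstep_k = nek1/32;                   // K is processed in blocks of 32 rows
    int gcd_k   = simple_gcd(nstep_k, nth);  // number of K-cache slices
    int nth_k   = nth/gcd_k;                 // number of head groups
    int ith_k   = ith % gcd_k;               // this thread's K slice
    int ith_q   = ith / gcd_k;               // this thread's head group
    k_first = ith_k*(nek1/gcd_k); k_count = nek1/gcd_k;
    q_first = ith_q*(rk2/nth_k);  q_count = rk2/nth_k;
}
```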
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 0ed84956..c1ebd870 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -258,8 +258,8 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h)
 if (GGML_IQK_MUL_MAT)
     message(STATUS "Using optimized iqk matrix multiplications")
     add_compile_definitions(GGML_USE_IQK_MULMAT)
-    set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp)
-    set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h)
+    set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp)
+    set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h)
     if (GGML_IQK_FA_ALL_QUANTS)
         message(STATUS "Including all IQK FA kernels")
         add_compile_definitions(GGML_IQK_FA_ALL_QUANTS)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 46e1a548..e5ad15f2 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -17870,46 +17870,57 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }
 #if GGML_USE_IQK_MULMAT
-    if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
-        //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
-        //        k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
-        // I keep changing my mind what is the best strategy to split the threads when processing
-        // multiple heads. This is my current thinking, the commented out code below was the previous.
-        int ntg = nth/simple_gcd(neq2*neq3, nth);
-        int64_t neq1g = (neq1 + ntg - 1)/ntg;
-        //int64_t work_per_slice = D*nek1*neq1;
-        //int ntg = 1;
-        //
-        // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
-        // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
-        // the number of threads processing the (iq2, iq3) matrix.
-        //
-        //if (neq1 >= 8*nth) {
-        //    if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
-        //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
-        //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
-        //}
-        int counter = 0;
-        for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
-            for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
-                if (counter++ % (nth/ntg) == ith/ntg) {
-                    int iq1 = (ith%ntg)*neq1g;
-                    int this_neq1 = MIN(neq1g, neq1-iq1);
-                    if (!iqk_flash_attn_noalibi(k->type, v->type,
-                                Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
-                                (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
-                                (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
-                                (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
-                                (const void  *)((const char *)mask->data + iq1*mask->nb[1]),
-                                scale, softcap,
-                                (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
-                }
-            }
-        }
-        return;
-IQK_Flash_Attn_NotAvailable:;
-        printf("iqk_flash was rejected\n");
-    }
+    if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
+                        q->ne[3], q->ne[2], q->nb[3], q->nb[2],
+                        k->ne[3], k->ne[2], k->nb[3], k->nb[2],
+                        v->ne[3], v->ne[2], v->nb[3], v->nb[2],
+                        dst->ne[2], dst->ne[1], dst->nb[1],
+                        k->type, v->type,
+                        Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
+                        q->data, k->data, v->data, mask->data,
+                        scale, softcap, (float *)dst->data,
+                        params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
+
+//    if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
+//        //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
+//        //        k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
+//        // I keep changing my mind what is the best strategy to split the threads when processing
+//        // multiple heads. This is my current thinking, the commented out code below was the previous.
+//        int ntg = nth/simple_gcd(neq2*neq3, nth);
+//        int64_t neq1g = (neq1 + ntg - 1)/ntg;
+//        //int64_t work_per_slice = D*nek1*neq1;
+//        //int ntg = 1;
+//        //
+//        // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
+//        // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
+//        // the number of threads processing the (iq2, iq3) matrix.
+//        //
+//        //if (neq1 >= 8*nth) {
+//        //    if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+//        //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+//        //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+//        //}
+//        int counter = 0;
+//        for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
+//            for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
+//                if (counter++ % (nth/ntg) == ith/ntg) {
+//                    int iq1 = (ith%ntg)*neq1g;
+//                    int this_neq1 = MIN(neq1g, neq1-iq1);
+//                    if (!iqk_flash_attn_noalibi(k->type, v->type,
+//                                Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+//                                (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
+//                                (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
+//                                (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
+//                                (const void  *)((const char *)mask->data + iq1*mask->nb[1]),
+//                                scale, softcap,
+//                                (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
+//                }
+//            }
+//        }
+//        return;
+//IQK_Flash_Attn_NotAvailable:;
+//        printf("iqk_flash was rejected\n");
+//    }
 #endif

     const uint32_t n_head = neq2;
@@ -21534,6 +21545,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     const int64_t D = MAX(Dk, Dv);
                     cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
+#if GGML_USE_IQK_MULMAT
+                    const struct ggml_tensor * q = node->src[0];
+                    const struct ggml_tensor * k = node->src[1];
+                    if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
+                        int nstep_k = k->ne[1]/32;
+                        int gcd_k = simple_gcd(nstep_k, n_tasks);
+                        if (gcd_k > 1) {
+                            int nth_k = n_tasks/gcd_k;
+                            int rk2 = q->ne[2]/k->ne[2];
+                            if (rk2%nth_k == 0) {
+                                size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
+                                if (ggml_is_quantized(k->type)) {
+                                    enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+                                    size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+                                    size += q->ne[2]*row_size;
+                                }
+                                cur = MAX(cur, size);
+                            }
+                        }
+                    }
+#endif
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
new file mode 100644
index 00000000..dc3e369f
--- /dev/null
+++ b/ggml/src/iqk/iqk_common.h
@@ -0,0 +1,138 @@
+// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
+// vi: set et ft=cpp fenc=utf-8 :vi
+//
+//
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#include "iqk_config.h"
+
+#if defined IQK_IMPLEMENT
+
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "iqk_mul_mat.h"
+#include "iqk_quantize.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#define FA_TIMING 0
+
+#include <utility>
+#include <array>
+#if FA_TIMING
+#include <chrono>
+#include <mutex>
+struct Perf {
+    using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;
+    std::array<double, 5> times = {};
+    std::mutex mutex;
+    bool report;
+    static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
+    inline void accum(int what, const TimePoint& t1) {
+        auto t2 = cur_time();
+        auto dt = delta(t1, t2);
+        std::lock_guard<std::mutex> lock(mutex);
+        times[what] += dt;
+    }
+    inline void accum_nolock(int what, const TimePoint& t1) {
+        auto t2 = cur_time();
+        auto dt = delta(t1, t2);
+        times[what] += dt;
+    }
+    inline void add(const Perf& other) {
+        std::lock_guard<std::mutex> lock(mutex);
+        for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
+    }
+    Perf(bool r) : report(r) {}
+    ~Perf() {
+        if (report) {
+            double tot = 0;
+            for (auto& t : times) tot += t;
+            if (!tot) return;
+            printf("======================= Timing: %g ms in total\n", tot);
+            for (int i = 0; i < int(times.size()); ++i) {
+                if (times[i]) {
+                    printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
+                }
+            }
+        }
+    }
+    static Perf& instance() {
+        static Perf p(true);
+        return p;
+    }
+    static double delta(const TimePoint& t1, const TimePoint& t2) {
+        return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
+    }
+};
+#endif
+
+#ifdef __AVX2__
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+#endif
+
+namespace {
+
+typedef struct {
+    int32_t i1;
+    int32_t i2;
+} mmid_row_mapping;
+
+struct DataInfo {
+    float       * s;
+    const char  * cy;
+    size_t        bs;
+    size_t        by;
+    int           cur_y = 0;
+    int           ne11;
+    const mmid_row_mapping * row_mapping = nullptr;
+    size_t        bs2 = 0;
+
+    inline const char * src1_row(int iy) const {
+        if (!row_mapping) return cy + (cur_y + iy)*by;
+        int i11 = row_mapping[cur_y + iy].i1 % ne11;
+        int i12 = row_mapping[cur_y + iy].i2;
+        return cy + (i11 + i12*ne11)*by;
+    }
+
+    inline void store(int ix, int iy, float result) const {
+        *(dst_row(iy) + ix) = result;
+    }
+#ifdef __AVX__
+    inline void store(int ix, int iy, __m128 result) const {
+        _mm_storeu_ps(dst_row(iy) + ix, result);
+    }
+    inline void store(int ix, int iy, __m256 result) const {
+        _mm256_storeu_ps(dst_row(iy) + ix, result);
+    }
+#endif
+#ifdef __AVX512F__
+    inline void store(int ix, int iy, __m512 result) const {
+        _mm512_storeu_ps(dst_row(iy) + ix, result);
+    }
+#endif
+#ifdef __ARM_NEON
+    inline void store(int ix, int iy, float32x4_t result) const {
+        vst1q_f32(dst_row(iy) + ix, result);
+    }
+#endif
+    inline float * dst_row(int iy) const {
+        if (!row_mapping) return s + (cur_y + iy)*bs;
+        int i12 = row_mapping[cur_y + iy].i2;
+        int i1  = row_mapping[cur_y + iy].i1;
+        int i2  = i12;
+        return s + i1*bs + i2*bs2;
+    }
+};
+
+typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
+
+#endif
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
new file mode 100644
index 00000000..fecd818b
--- /dev/null
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -0,0 +1,194 @@
+#include "iqk_config.h"
+#include "iqk_mul_mat.h"
+#include "iqk_flash_impl.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <algorithm>
+#include <cstdio>
+#include <vector>
+#include <cstdint>
+#include <cstring>
+#include <cmath>
+
+namespace {
+inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
+    while (a != b) {
+        if (a > b) a -= b;
+        else b -= a;
+    }
+    return a;
+}
+}
+
+// TODO: get the ggml_type enum here without polution
+//
+bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
+                            int neq3, int neq2, long nbq3, long nbq2,
+                            int nek3, int nek2, long nbk3, long nbk2,
+                            int nev3, int nev2, long nbv3, long nbv2,
+                            int ne2, int ne1, long nb1,
+                            int int_type_k,         // type of k
+                            int int_type_v,         // type of v
+                            int Dk,                 // K head size
+                            int Dv,                 // V head size
+                            int neq1,               // number of columns in q
+                            int nek1,               // number of rows in k
+                            int stride_q,           // distance between q columns in bytes
+                            int stride_k,           // distance between k rows in bytes
+                            int stride_v,           // distance between v rows in bytes
+                            int stride_m,           // distance between mask rows (in bytes
+                            const void  * q,        // q matrix.
+                            const void  * k,        // k matrix. Assumed to be fp16, nq x nk elements
+                            const void  * v,        // v matrix. Assumed to be fp16, nq x nk elements
+                            const void  * mask,     // mask. If not null, assumed to be fp16. nq x nk elements
+                            float         scale,    // scale applied before softmax
+                            float         softcap,  // if > 0, a "soft-cap" operation is applied before softmax
+                            float       * qkv,      // v*softmax(scale*(k*q))
+                            [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
+                            int ith, int nth) {
+
+    if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
+
+    int rk2 = neq2/nek2;
+    int rv2 = neq2/nev2;
+    int rk3 = neq3/nek3;
+    int rv3 = neq3/nev3;
+
+    // Getting confused all the time about where to load data from and store the results to
+    // (especially when combining the results from the threads).
+    // So, for now, making it work just for MLA (nek2 = 1).
+    // I think it would also speed up things for GQA, but I'm leaving this for another day.
+    if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) {
+        int nstep_k = nek1/32;
+        int gcd_k = simple_gcd(nstep_k, nth);
+        if (gcd_k >= 1) {
+            int nth_k = nth/gcd_k;
+            if (rk2%nth_k == 0) {
+                int ith_k = ith%gcd_k;
+                int ith_q = ith/gcd_k;
+                auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k;
+                auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
+                auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2;
+                auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
+                auto work = (char *)work_buffer;
+
+                // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float)
+                // In addition, we need M, S for the rk2/nth_k rows the thread is processing
+                // => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not
+                // writing onto the same cache line.
+                auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float);
+                auto result_buffer = work;
+                auto work_this_thread = (float *)(result_buffer + ith*size_thread);
+                if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+                        Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
+                        (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
+                        scale, softcap,
+                        work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false;
+
+                barrier(barrier_data);
+
+                // TODO: simdify this
+                for (int j = ith; j < rk2; j += nth) {
+                    auto Racc = qkv + j*nb1/sizeof(float);
+                    float M = -INFINITY, S = 0;
+                    int jth_q = j/(rk2/nth_k);
+                    int jj = j%(rk2/nth_k);
+                    for (int j1 = 0; j1 < rk2/nth_k; ++j1) {
+                        auto R  = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread);
+                        auto Mj = R + Dv*rk2/nth_k;
+                        auto Sj = Mj + rk2/nth_k;
+                        R += jj*Dv;
+                        if (Mj[jj] == -INFINITY) continue;
+                        if (Mj[jj] > M) {
+                            if (M == -INFINITY) {
+                                std::memcpy(Racc, R, Dv*sizeof(float));
+                                S = Sj[jj];
+                            } else {
+                                float c = exp(M - Mj[jj]);
+                                S = c*S + Sj[jj];
+                                for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
+                            }
+                            M = Mj[jj];
+                        } else {
+                            float c = exp(Mj[jj] - M);
+                            S += c*Sj[jj];
+                            for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
+                        }
+                    }
+                    float norm = S > 0 ? 1/S : 1;
+                    for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
+                }
+                return true;
+
+            }
+        }
+    }
+
+    // I keep changing my mind what is the best strategy to split the threads when processing
+    // multiple heads. This is my current thinking, the commented out code below was the previous.
+    int ntg = nth/simple_gcd(neq2*neq3, nth);
+    int neq1g = (neq1 + ntg - 1)/ntg;
+    //int64_t work_per_slice = D*nek1*neq1;
+    //int ntg = 1;
+    //
+    // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
+    // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
+    // the number of threads processing the (iq2, iq3) matrix.
+    //
+    //if (neq1 >= 8*nth) {
+    //    if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+    //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+    //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+    //}
+    int counter = 0;
+    for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
+        for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
+            if (counter++ % (nth/ntg) == ith/ntg) {
+                int iq1 = (ith%ntg)*neq1g;
+                int this_neq1 = std::min(neq1g, neq1-iq1);
+                if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+                        Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
+                        (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
+                        (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
+                        (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
+                        (const void *)((const char *)mask + iq1*stride_m),
+                        scale, softcap,
+                        (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+#else
+
+bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int type_mask, [[maybe_unused]] float max_bias,
+                            [[maybe_unused]] int neq3, [[maybe_unused]] int neq2, [[maybe_unused]] long nbq3, [[maybe_unused]] long nbq2,
+                            [[maybe_unused]] int nek3, [[maybe_unused]] int nek2, [[maybe_unused]] long nbk3, [[maybe_unused]] long nbk2,
+                            [[maybe_unused]] int nev3, [[maybe_unused]] int nev2, [[maybe_unused]] long nbv3, [[maybe_unused]] long nbv2,
+                            [[maybe_unused]] int ne2, [[maybe_unused]] int ne1, [[maybe_unused]] long nb1,
+                            [[maybe_unused]] int int_type_k,         // type of k
+                            [[maybe_unused]] int int_type_v,         // type of v
+                            [[maybe_unused]] int D,                  // head size
+                            [[maybe_unused]] int nq,                 // number of columns in q
+                            [[maybe_unused]] int nk,                 // number of rows in k
+                            [[maybe_unused]] int stride_q,           // distance between q columns in bytes
+                            [[maybe_unused]] int stride_k,           // distance between k rows in bytes
+                            [[maybe_unused]] int stride_v,           // distance between v rows in bytes
+                            [[maybe_unused]] int stride_m,           // distance between mask rows (in bytes
+                            [[maybe_unused]] const void  * q,        // q matrix.
+                            [[maybe_unused]] const void  * k,        // k matrix. Assumed to be fp16, nq x nk elements
+                            [[maybe_unused]] const void  * v,        // v matrix. Assumed to be fp16, nq x nk elements
+                            [[maybe_unused]] const void  * mask,     // mask. If not null, assumed to be fp16. nq x nk elements
+                            [[maybe_unused]] float         scale,    // scale applied before softmax
+                            [[maybe_unused]] float         softcap,  // if > 0, a "soft-cap" operation is applied before softmax
+                            [[maybe_unused]] float       * qkv,      // v*softmax(scale*(k*q))
+                            [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
+                            [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
+    return false;
+}
+
+#endif
+
diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h
new file mode 100644
index 00000000..bc68e0d8
--- /dev/null
+++ b/ggml/src/iqk/iqk_flash_impl.h
@@ -0,0 +1,23 @@
+#pragma once
+
+bool iqk_flash_attn_impl(int type_k,              // type of k
+                         int type_v,              // type of v
+                         int Dk,                  // K head size
+                         int Dv,                  // V head size
+                         int nq,                  // number of columns in q
+                         int nk,                  // number of rows in k
+                         int stride_q,            // distance between q columns in bytes
+                         int stride_k,            // distance between k rows in bytes
+                         int stride_v,            // distance between v rows in bytes
+                         int stride_m,            // distance between mask rows (in bytes
+                         int stride_qkv,          // distance between rows in mask (in bytes)
+                         const float * q,         // q matrix.
+                         const void  * k,         // k matrix. Assumed to be fp16, nq x nk elements
+                         const void  * v,         // v matrix. Assumed to be fp16, nq x nk elements
+                         const void  * mask,      // mask. If not null, assumed to be fp16. nq x nk elements
+                         float         scale,     // scale applied before softmax
+                         float         softcap,   // if > 0, a "soft-cap" operation is applied before softmax
+                         float       * qkv,       // v*softmax(scale*(k*q))
+                         float * M,
+                         float * S);
+
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 1f18837c..14cc64db 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -15870,23 +15870,23 @@ struct FlashMS {
     inline void update_M_S(int j, float32x4_t * vk) {
         float smax = load_and_scale(j, vk);
         update_M(j, smax);
-        update_S(j, vk);
+        if (M[j] > -INFINITY) update_S(j, vk);
     }
     inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
         float smax = load_apply_mask_and_scale(j, vk, mask);
         update_M(j, smax);
-        update_S(j, vk);
+        if (M[j] > -INFINITY) update_S(j, vk);
     }
 #else
     inline void update_M_S(int j, F16::Data * vk) {
         float smax = load_and_scale(j, vk);
         update_M(j, smax);
-        update_S(j, vk);
+        if (M[j] > -INFINITY) update_S(j, vk);
     }
     inline void update_M_S(int j, F16::Data * vk, const char * mask) {
         float smax = load_apply_mask_and_scale(j, vk, mask);
         update_M(j, smax);
-        update_S(j, vk);
+        if (M[j] > -INFINITY) update_S(j, vk);
     }
 #endif

@@ -16037,7 +16037,7 @@ struct FlashQKV {
         }
     }

-    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
+    inline void normalize_and_store_1row(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
         GGML_ASSERT(fms.S[j] > 0);
         auto norm = F16::set1(1/fms.S[j]);
         //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
@@ -16047,21 +16047,43 @@
         }
     }

-    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int nq1, int stride_qkv, float * qkv) const {
-        auto R = qkv_cache;
-        for (int j = 0; j < nq1; ++j) {
-            normalize_and_store(fms, j, R, qkv);
-            qkv += stride_qkv;
-            R   += D;
+    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
+        if (M && S) {
+            std::memcpy(M, fms.M, nq1*sizeof(float));
+            std::memcpy(S, fms.S, nq1*sizeof(float));
+            auto R = qkv_cache;
+            for (int j = 0; j < nq1; ++j) {
+                std::memcpy(qkv, R, D*sizeof(float));
+                qkv += stride_qkv;
+                R   += D;
+            }
+        } else {
+            auto R = qkv_cache;
+            for (int j = 0; j < nq1; ++j) {
+                normalize_and_store_1row(fms, j, R, qkv);
+                qkv += stride_qkv;
+                R   += D;
+            }
         }
     }

-    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int stride_qkv, float * qkv) const {
-        auto R = qkv_cache;
-        for (int j = 0; j < q_step; ++j) {
-            normalize_and_store(fms, j, R, qkv);
-            qkv += stride_qkv;
-            R   += D;
+    inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int stride_qkv, float * qkv, float * M, float * S) const {
+        if (M && S) {
+            std::memcpy(M, fms.M, q_step*sizeof(float));
+            std::memcpy(S, fms.S, q_step*sizeof(float));
+            auto R = qkv_cache;
+            for (int j = 0; j < q_step; ++j) {
+                std::memcpy(qkv, R, D*sizeof(float));
+                qkv += stride_qkv;
+                R   += D;
+            }
+        } else {
+            auto R = qkv_cache;
+            for (int j = 0; j < q_step; ++j) {
+                normalize_and_store_1row(fms, j, R, qkv);
+                qkv += stride_qkv;
+                R   += D;
+            }
         }
     }

@@ -16435,7 +16457,8 @@ template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHe
 void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
         FlashMS<q_step, k_step>& fms,
         FlashQKV<Dv, q_step, k_step>& fqkv,
-        const float * q, const char * mask, float * qkv) {
+        const float * q, const char * mask, float * qkv,
+        float * M, float * S) {
 #ifdef __aarch64__
     float16_t q_f16[Dk*q_step];
 #endif
@@ -16459,11 +16482,12 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
             vh.next_block();
             mr += k_step*sizeof(ggml_half);
         }
-        fqkv.normalize_and_store(fms, stride_qkv, qkv);
+        fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);

         q    += q_step*stride_q;
         mask += q_step*stride_m;
         qkv  += q_step*stride_qkv;
+        if (M && S) { M += q_step; S += q_step; }
     }
     int n_left = nq1 - q_step*(nq1/q_step);
     if (n_left > 0) {
@@ -16485,7 +16509,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
             vh.next_block();
             mr += k_step*sizeof(ggml_half);
         }
-        fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
+        fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
     }
 }

@@ -16493,7 +16517,8 @@ template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHe
 void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
         FlashMS<q_step, k_step>& fms,
         FlashQKV<Dv, q_step, k_step>& fqkv,
-        const float * q, const char * mask, float * qkv) {
+        const float * q, const char * mask, float * qkv,
+        float * M, float * S) {
     typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
 #if FA_TIMING
     Perf perf(false);
@@ -16528,15 +16553,16 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
         }
 #if FA_TIMING
         t1 = Perf::cur_time();
-        fqkv.normalize_and_store(fms, stride_qkv, qkv);
+        fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
         perf.accum_nolock(3, t1);
 #else
-        fqkv.normalize_and_store(fms, stride_qkv, qkv);
+        fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
 #endif

         q    += q_step*stride_q;
         mask += q_step*stride_m;
         qkv  += q_step*stride_qkv;
+        if (M && S) { M += q_step; S += q_step; }
     }
     int n_left = nq1 - q_step*(nq1/q_step);
     if (n_left > 0) {
@@ -16552,7 +16578,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
             vh.next_block();
             mr += k_step*sizeof(ggml_half);
         }
-        fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
+        fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
     }
 #if FA_TIMING
     Perf::instance().add(perf);
@@ -16580,12 +16606,12 @@ struct FlashAttn {

     template <typename KHelper, typename VHelper>
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
-            const float * q, const char * mask, float * qkv) {
+            const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
        if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
                      std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
                      std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
            compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                   kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+                   kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
        }
        else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
            if (nq1 >= 8) {
@@ -16597,10 +16623,10 @@ struct FlashAttn {
                HelperQ80R8<Dk, k_step> khr4(nk1, kh);
 #endif
                compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                       khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+                       khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
            } else{
                compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                       kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+                       kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
            }
        }
        else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
@@ -16613,14 +16639,14 @@ struct FlashAttn {
                HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
 #endif
                compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                       khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+                       khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
            } else{
                compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                       kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+                       kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
            }
        }
        else {
            compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
-                   kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+                   kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
        }
    }

@@ -16987,7 +17013,7 @@ struct FlashAttnBF16 {

     template <typename KHelper, typename VHelper>
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
-            const float * q, const char * mask, float * qkv) {
+            const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
         ggml_bf16_t q_bf16[q_step*Dk];
 #if FA_TIMING
         Perf perf(false);
@@ -17023,7 +17049,7 @@ struct FlashAttnBF16 {
 #if FA_TIMING
             t1 = Perf::cur_time();
 #endif
-            fqkv.normalize_and_store(fms, stride_qkv, qkv);
+            fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
 #if FA_TIMING
             perf.accum_nolock(4, t1);
 #endif
@@ -17046,7 +17072,7 @@
             vh.next_block();
             mr += k_step*sizeof(ggml_half);
         }
-        fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
+        fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
     }
 #if FA_TIMING
     Perf::instance().add(perf);
@@ -17060,32 +17086,32 @@
 template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
 inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
-                        const float * q, const char * mask, float scale, float softcap, float * qkv) {
+                        const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
     if (nk1 >= 256) { //4096) {
         if (nq1 >= 64) {
             FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
             return;
         }
         if (nq1 >= 32) {
             FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
             return;
         }
         if (nq1 >= 16) {
             FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
             return;
         }
     }
     if (nq1 >= 8) {
         FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
     } else {
         FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
     }
 }

@@ -17093,27 +17119,27 @@
 template <int Dk, int Dv, int k_step>
 inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
                         const float * q, const char * k, const char * v, const char * mask,
-                        float scale, float softcap, float * qkv) {
+                        float scale, float softcap, float * qkv, float * M, float * S) {
     HelperBF16<Dk, k_step> kh(k, stride_k);
     HelperBF16<Dv, k_step> vh(v, stride_v);
     if (nk1 >= 4096) {
         if (nq1 >= 64) {
             FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
             return;
         }
         else if (nq1 >= 16) {
             FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
             return;
         }
     }
     if (nq1 >= 8) {
         FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
     } else {
         FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
     }
 }
 #endif
@@ -17122,43 +17148,43 @@
 template <int Dk, int Dv, int k_step, typename KHelper>
 inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
                         int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
                         const float * q, const char * v, const char * mask,
-                        float scale, float softcap, float * qkv) {
+                        float scale, float softcap, float * qkv, float * M, float * S) {

     switch (type_v) {
         case GGML_TYPE_F16: {
             HelperF16<Dv, k_step> vh(v, stride_v);
-            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
 #ifdef __AVX512BF16__
         case GGML_TYPE_BF16: {
             HelperBF16<Dv, k_step> vh(v, stride_v);
-            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
 #endif
         case GGML_TYPE_Q8_0: {
             HelperQ80<Dv, k_step> vh(v, stride_v);
-            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_Q8_KV: {
             HelperQ8KV<Dv, k_step> vh(v, stride_v);
-            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_Q6_0: {
             HelperQ60<Dv, k_step> vh(v, stride_v);
-            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
 #if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0: {
             HelperQ40<Dv, k_step> vh(v, stride_v);
-            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_Q4_1: {
             HelperQ41<Dv, k_step> vh(v, stride_v);
-            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_IQ4_NL: {
             HelperIQ4nl<Dv, k_step> vh(v, stride_v);
-            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         } break;
 #endif
         default: break;
@@ -17169,37 +17195,37 @@
 template <int Dk, int Dv, int k_step>
 inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
                         int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
                         const float * q, const char * k, const char * v, const char * mask,
-                        float scale, float softcap, float * qkv) {
+                        float scale, float softcap, float * qkv, float * M, float * S) {

     switch (type_k) {
         case GGML_TYPE_F16: {
             HelperF16<Dk, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_Q8_0: {
             HelperQ80<Dk, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_Q8_KV: {
             HelperQ8KV<Dk, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_Q6_0: {
             HelperQ60<Dk, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
 #if GGML_IQK_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0: {
             HelperQ40<Dk, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_Q4_1: {
             HelperQ41<Dk, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
         case GGML_TYPE_IQ4_NL: {
             HelperIQ4nl<Dk, k_step> kh(k, stride_k);
-            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
         } break;
 #endif
         default: break;
@@ -17223,13 +17249,13 @@ inline bool flash_attn_is_supported(ggml_type type) {

 template <int step_k, typename KHelper, typename VHelper>
 inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
                         int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
-                        const float * q, const char * mask, float scale, float softcap, float * qkv) {
+                        const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
     if (nq1 % 8 == 0) {
         FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
     } else {
         FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
-        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
     }
 }

@@ -17237,29 +17263,29 @@ template <int step_k>
 inline bool iqk_deepseek_helper(ggml_type type_k,
                         int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
                         const float * q, const char * k, const char * v, const char * mask,
-                        float scale, float softcap, float * qkv) {
+                        float scale, float softcap, float * qkv, float * M, float * S) {
     if (type_k == GGML_TYPE_Q8_0) {
         HelperQ80<576, step_k> kh((const char *)k, stride_k);
         HelperQ80<512, step_k> vh((const char *)v, stride_v);
-        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         return true;
     }
     if (type_k == GGML_TYPE_Q6_0) {
         HelperQ60<576, step_k> kh((const char *)k, stride_k);
         HelperQ60<512, step_k> vh((const char *)v, stride_v);
-        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         return true;
     }
     if (type_k == GGML_TYPE_Q8_KV) {
         HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
         HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
-        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         return true;
     }
     if (type_k == GGML_TYPE_F16) {
         HelperF16<576, step_k> kh((const char *)k, stride_k);
         HelperF16<512, step_k> vh((const char *)v, stride_v);
-        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
         return true;
     }
 #ifdef __AVX512BF16__
@@ -17268,10 +17294,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
         HelperBF16<512, step_k> vh((const char *)v, stride_v);
         if (nq1 % 8 == 0) {
             FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
         } else {
             FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
         }
         return true;
     }
@@ -17281,24 +17307,27 @@
 }

-bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
-                            int int_type_v,         // type of v
-                            int Dk,                 // K head size
-                            int Dv,                 // V head size
-                            int nq1,                // number of columns in q
-                            int nk1,                // number of rows in k
-                            int stride_q,           // distance between q columns in bytes
-                            int stride_k,           // distance between k rows in bytes
-                            int stride_v,           // distance between v rows in bytes
-                            int stride_m,           // distance between mask rows (in bytes
-                            int stride_qkv,         // distance between rows in mask (in bytes)
-                            const float * q,        // q matrix.
-                            const void  * k,        // k matrix. Assumed to be fp16, nq x nk elements
-                            const void  * v,        // v matrix. Assumed to be fp16, nq x nk elements
-                            const void  * mask,     // mask. If not null, assumed to be fp16. nq x nk elements
-                            float         scale,    // scale applied before softmax
-                            float         softcap,  // if > 0, a "soft-cap" operation is applied before softmax
-                            float       * qkv) {    // v*softmax(scale*(k*q))
+#include "iqk_flash_impl.h"
+
+bool iqk_flash_attn_impl(int int_type_k,          // type of k
+                         int int_type_v,          // type of v
+                         int Dk,                  // K head size
+                         int Dv,                  // V head size
+                         int nq1,                 // number of columns in q
+                         int nk1,                 // number of rows in k
+                         int stride_q,            // distance between q columns in bytes
+                         int stride_k,            // distance between k rows in bytes
+                         int stride_v,            // distance between v rows in bytes
+                         int stride_m,            // distance between mask rows (in bytes
+                         int stride_qkv,          // distance between rows in mask (in bytes)
+                         const float * q,         // q matrix.
+                         const void  * k,         // k matrix. Assumed to be fp16, nq x nk elements
+                         const void  * v,         // v matrix. Assumed to be fp16, nq x nk elements
+                         const void  * mask,      // mask. If not null, assumed to be fp16. nq x nk elements
+                         float         scale,     // scale applied before softmax
+                         float         softcap,   // if > 0, a "soft-cap" operation is applied before softmax
+                         float       * qkv,       // v*softmax(scale*(k*q))
+                         float * M, float * S) {

     if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
@@ -17309,13 +17338,13 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
         GGML_ASSERT(type_k == type_v);
         stride_q /= sizeof(float); // q stride as float
         return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
-                        q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv);
+                        q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
     }

     if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
     if (Dk != Dv && Dk != 192 && Dv != 128) return false;
     if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
-    if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false;
+    if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dk != 256) return false;

     auto ck = (const char *)k;
     auto cv = (const char *)v;
@@ -17329,15 +17358,15 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
         if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
         switch (Dk) {
             case 64:
-                iqk_flash_helper_T< 64,  64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 64,  64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 96:
-                iqk_flash_helper_T< 96,  96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 96,  96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 128:
-                iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 192:
-                iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 256:
-                iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             default:
                 return false;
         }
@@ -17346,15 +17375,15 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
         if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
         switch (Dk) {
             case 64:
-                iqk_flash_helper_T< 64,  64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 64,  64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 96:
-                iqk_flash_helper_T< 96,  96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 96,  96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 128:
-                iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 192:
-                iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 256:
-                iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             default:
                 return false;
         }
@@ -17366,21 +17395,21 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
     if (nk1%64 == 0) {
         switch (Dk) {
             case 64:
-                iqk_flash_helper_T< 64,  64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 64,  64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             // Disable until we fix accumulate_qkv for odd D/16
             //case 80:
             //    iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
             case 96:
-                iqk_flash_helper_T< 96,  96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T< 96,  96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             // Disable until we fix accumulate_qkv for odd D/16
             //case 112:
            //    iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
             case 128:
-                iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 192:
-                iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             case 256:
-                iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+                iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
             default:
                 return false;
         }
@@ -17388,21 +17417,21 @@ bool iqk_flash_attn_noalibi(int int_type_k,         // type of k
     switch (Dk) {
         case 64:
-            iqk_flash_helper_T< 64,  64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T< 64,  64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
         // Disable until we fix accumulate_qkv for odd D/16
         //case 80:
         //    iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
         case 96:
-            iqk_flash_helper_T< 96,  96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T< 96,  96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
         // Disable until we fix accumulate_qkv for odd D/16
        //case 112:
        //    iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
         case 128:
-            iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
         case 192:
-            iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
         case 256:
-            iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+            iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
         default:
             return false;
     }
@@ -17437,25 +17466,4 @@ bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/
     return false;
 }

-
-bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k,         // type of k
-                            [[maybe_unused]] int int_type_v,         // type of v
-                            [[maybe_unused]] int D,                  // head size
-                            [[maybe_unused]] int nq,                 // number of columns in q
-                            [[maybe_unused]] int nk,                 // number of rows in k
-                            [[maybe_unused]] int stride_q,           // distance between q columns in bytes
-                            [[maybe_unused]] int stride_k,           // distance between k rows in bytes
-                            [[maybe_unused]] int stride_v,           // distance between v rows in bytes
-                            [[maybe_unused]] int stride_m,           // distance between mask rows (in bytes
-                            [[maybe_unused]] int stride_qkv,         // distance between rows in mask (in bytes)
-                            [[maybe_unused]] const float * q,        // q matrix.
-                            [[maybe_unused]] const void  * k,        // k matrix. Assumed to be fp16, nq x nk elements
-                            [[maybe_unused]] const void  * v,        // v matrix. Assumed to be fp16, nq x nk elements
-                            [[maybe_unused]] const void  * mask,     // mask. If not null, assumed to be fp16. nq x nk elements
-                            [[maybe_unused]] float         scale,    // scale applied before softmax
-                            [[maybe_unused]] float         softcap,  // if > 0, a "soft-cap" operation is applied before softmax
-                            [[maybe_unused]] float       * qkv) {    // v*softmax(scale*(k*q))
-    return false;
-}
-
 #endif
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index 767f89cf..d91c4710 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -33,7 +33,14 @@ bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
                            int typeB, const void * B, long strideB,
                            float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);

-bool iqk_flash_attn_noalibi(int type_k,             // type of k
+typedef void (*barrier_t) (void *);
+
+bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
+                            int neq3, int neq2, long nbq3, long nbq2,
+                            int nek3, int nek2, long nbk3, long nbk2,
+                            int nev3, int nev2, long nbv3, long nbv2,
+                            int ne2, int ne1, long nb1,
+                            int type_k,             // type of k
                             int type_v,             // type of v
                             int Dk,                 // K head size
                             int Dv,                 // V head size
@@ -43,14 +50,15 @@ bool iqk_flash_attn_noalibi(int type_k,             // type of k
                             int stride_k,           // distance between k rows in bytes
                             int stride_v,           // distance between v rows in bytes
                             int stride_m,           // distance between mask rows (in bytes
-                            int stride_qkv,         // distance between rows in mask (in bytes)
-                            const float * q,        // q matrix.
+                            const void  * q,        // q matrix.
                             const void  * k,        // k matrix. Assumed to be fp16, nq x nk elements
                             const void  * v,        // v matrix. Assumed to be fp16, nq x nk elements
                             const void  * mask,     // mask. If not null, assumed to be fp16. nq x nk elements
                             float         scale,    // scale applied before softmax
                             float         softcap,  // if > 0, a "soft-cap" operation is applied before softmax
-                            float       * qkv);     // v*softmax(scale*(k*q))
+                            float       * qkv,      // v*softmax(scale*(k*q))
+                            void * work_buffer, barrier_t barrier, void * barrier_data,
+                            int ith, int nth);

 #ifdef __cplusplus
 }
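For reference, the per-thread scratch written by iqk_flash_attn_impl and consumed by the combine step follows the layout below; (Dv + 16) floats per row instead of the required (Dv + 2) keeps threads off each other's cache lines. A sketch of the layout, assuming rows = rk2/nth_k as in the diff:

```cpp
// One thread's slot in work_buffer (rows = number of head rows it owns).
// The M and S pointers passed into iqk_flash_attn_impl are simply the
// offsets (Dv + 0)*rows and (Dv + 1)*rows into the same slot.
float * slot = (float *)work_buffer + ith*(Dv + 16)*rows;
float * R = slot;                  // rows x Dv unnormalized accumulators
float * M = slot + (Dv + 0)*rows;  // per-row softmax maxima
float * S = slot + (Dv + 1)*rows;  // per-row softmax sums
// remaining (14*rows) floats: padding against false sharing
```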