author    Kawrakow <iwankawrakow@gmail.com>  2025-03-07 09:46:58 +0200
committer GitHub <noreply@github.com>        2025-03-07 09:46:58 +0200
commit    3d85a1d66302989401f92a5ae347577b03cbdaa7 (patch)
tree      7d9ea15568de65954ebddbf71792ad781841fd7f
parent    c67a37b251fc22b0f8b8313ea5c76a73ff6ed49f (diff)
Better FlashMLA (#243)
* This is a better FA for TG. It should benefit MLA and GQA. Tested to work with DeepSeek-Lite MLA, not yet for GQA. For tg64@pp8192 it is ~13% faster than MLA without FA, and 57% faster than the main branch FA.
* WIP
* Cleanup
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/src/CMakeLists.txt            4
-rw-r--r--  ggml/src/ggml.c                  112
-rw-r--r--  ggml/src/iqk/iqk_common.h        138
-rw-r--r--  ggml/src/iqk/iqk_flash_attn.cpp  194
-rw-r--r--  ggml/src/iqk/iqk_flash_impl.h     23
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp     274
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.h        16
7 files changed, 582 insertions, 179 deletions
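
A note on the algorithm behind this commit (notation ours, not from the commit): for token generation the K/V cache is split along the sequence dimension across threads, and each thread returns, per query row, an unnormalized accumulator $R = \sum_i e^{s_i - M} v_i$ together with its running maximum $M$ and softmax sum $S = \sum_i e^{s_i - M}$ over its slice of the keys. Partial results over two disjoint slices $A$ and $B$ merge with the standard online-softmax identity

$$M = \max(M_A, M_B), \qquad S = e^{M_A - M} S_A + e^{M_B - M} S_B, \qquad R = e^{M_A - M} R_A + e^{M_B - M} R_B,$$

and the final output row is $R/S$. The per-row reduction in ggml/src/iqk/iqk_flash_attn.cpp below applies this across all threads.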
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 0ed84956..c1ebd870 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -258,8 +258,8 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h)
if (GGML_IQK_MUL_MAT)
message(STATUS "Using optimized iqk matrix multiplications")
add_compile_definitions(GGML_USE_IQK_MULMAT)
- set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp)
- set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h)
+ set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp)
+ set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h)
if (GGML_IQK_FA_ALL_QUANTS)
message(STATUS "Including all IQK FA kernels")
add_compile_definitions(GGML_IQK_FA_ALL_QUANTS)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 46e1a548..e5ad15f2 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -17870,46 +17870,57 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
#if GGML_USE_IQK_MULMAT
- if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
- //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
- // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
- // I keep changing my mind what is the best strategy to split the threads when processing
- // multiple heads. This is my current thinking, the commented out code below was the previous.
- int ntg = nth/simple_gcd(neq2*neq3, nth);
- int64_t neq1g = (neq1 + ntg - 1)/ntg;
- //int64_t work_per_slice = D*nek1*neq1;
- //int ntg = 1;
- //
- // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
- // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
- // the number of threads processing the (iq2, iq3) matrix.
- //
- //if (neq1 >= 8*nth) {
- // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
- // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
- // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
- //}
- int counter = 0;
- for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
- for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
- if (counter++ % (nth/ntg) == ith/ntg) {
- int iq1 = (ith%ntg)*neq1g;
- int this_neq1 = MIN(neq1g, neq1-iq1);
- if (!iqk_flash_attn_noalibi(k->type, v->type,
- Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
- (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
- (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
- (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
- (const void *)((const char *)mask->data + iq1*mask->nb[1]),
- scale, softcap,
- (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
- }
- }
- }
- return;
-IQK_Flash_Attn_NotAvailable:;
- printf("iqk_flash was rejected\n");
- }
+ if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
+ q->ne[3], q->ne[2], q->nb[3], q->nb[2],
+ k->ne[3], k->ne[2], k->nb[3], k->nb[2],
+ v->ne[3], v->ne[2], v->nb[3], v->nb[2],
+ dst->ne[2], dst->ne[1], dst->nb[1],
+ k->type, v->type,
+ Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
+ q->data, k->data, v->data, mask->data,
+ scale, softcap, (float *)dst->data,
+ params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
+
+// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
+// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
+// // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
+// // I keep changing my mind what is the best strategy to split the threads when processing
+// // multiple heads. This is my current thinking, the commented out code below was the previous.
+// int ntg = nth/simple_gcd(neq2*neq3, nth);
+// int64_t neq1g = (neq1 + ntg - 1)/ntg;
+// //int64_t work_per_slice = D*nek1*neq1;
+// //int ntg = 1;
+// //
+// // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
+// // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
+// // the number of threads processing the (iq2, iq3) matrix.
+// //
+// //if (neq1 >= 8*nth) {
+// // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+// // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+// // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+// //}
+// int counter = 0;
+// for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
+// for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
+// if (counter++ % (nth/ntg) == ith/ntg) {
+// int iq1 = (ith%ntg)*neq1g;
+// int this_neq1 = MIN(neq1g, neq1-iq1);
+// if (!iqk_flash_attn_noalibi(k->type, v->type,
+// Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+// (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
+// (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
+// (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
+// (const void *)((const char *)mask->data + iq1*mask->nb[1]),
+// scale, softcap,
+// (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
+// }
+// }
+// }
+// return;
+//IQK_Flash_Attn_NotAvailable:;
+// printf("iqk_flash was rejected\n");
+// }
#endif
const uint32_t n_head = neq2;
@@ -21534,6 +21545,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
const int64_t D = MAX(Dk, Dv);
cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
+#if GGML_USE_IQK_MULMAT
+ const struct ggml_tensor * q = node->src[0];
+ const struct ggml_tensor * k = node->src[1];
+ if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
+ int nstep_k = k->ne[1]/32;
+ int gcd_k = simple_gcd(nstep_k, n_tasks);
+ if (gcd_k > 1) {
+ int nth_k = n_tasks/gcd_k;
+ int rk2 = q->ne[2]/k->ne[2];
+ if (rk2%nth_k == 0) {
+ size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
+ if (ggml_is_quantized(k->type)) {
+ enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+ size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+ size += q->ne[2]*row_size;
+ }
+ cur = MAX(cur, size);
+ }
+ }
+ }
+#endif
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
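
To make the new scratch-buffer sizing in ggml_graph_plan concrete, here is a minimal sketch (ours); the DeepSeek-Lite-like numbers are illustrative assumptions, not values from the commit. With an f16 cache nothing extra is reserved; for a quantized cache the plan additionally reserves q->ne[2] rows of the vec_dot type, as in the hunk above.

    #include <cstddef>
    #include <cstdio>

    static unsigned gcd(unsigned a, unsigned b) { while (b) { unsigned t = a % b; a = b; b = t; } return a; }

    int main() {
        // Illustrative shapes: Dv = 512, 16 query heads per KV head (rk2),
        // 8192 cached tokens, 16 threads.
        int Dv = 512, rk2 = 16, n_tasks = 16, n_kv = 8192;
        int nstep_k = n_kv/32;               // k_step = 32 rows per block
        int gcd_k   = gcd(nstep_k, n_tasks); // threads that split the K range
        int nth_k   = n_tasks/gcd_k;         // threads that split the heads
        if (gcd_k > 1 && rk2 % nth_k == 0) {
            // Dv floats of accumulator per head, plus room for M and S,
            // padded to Dv + 16 so threads never share a cache line.
            size_t size = (size_t)(Dv + 16)*(rk2/nth_k)*sizeof(float)*n_tasks;
            printf("flash-attn scratch: %zu bytes\n", size); // 540672 here
        }
        return 0;
    }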
diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
new file mode 100644
index 00000000..dc3e369f
--- /dev/null
+++ b/ggml/src/iqk/iqk_common.h
@@ -0,0 +1,138 @@
+// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
+// vi: set et ft=cpp fenc=utf-8 :vi
+//
+//
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#include "iqk_config.h"
+
+#if defined IQK_IMPLEMENT
+
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "iqk_mul_mat.h"
+#include "iqk_quantize.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#define FA_TIMING 0
+
+#include <utility>
+#include <array>
+#if FA_TIMING
+#include <chrono>
+#include <mutex>
+struct Perf {
+ using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;
+ std::array<double, 5> times = {};
+ std::mutex mutex;
+ bool report;
+ static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
+ inline void accum(int what, const TimePoint& t1) {
+ auto t2 = cur_time();
+ auto dt = delta(t1, t2);
+ std::lock_guard<std::mutex> lock(mutex);
+ times[what] += dt;
+ }
+ inline void accum_nolock(int what, const TimePoint& t1) {
+ auto t2 = cur_time();
+ auto dt = delta(t1, t2);
+ times[what] += dt;
+ }
+ inline void add(const Perf& other) {
+ std::lock_guard<std::mutex> lock(mutex);
+ for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
+ }
+ Perf(bool r) : report(r) {}
+ ~Perf() {
+ if (report) {
+ double tot = 0;
+ for (auto& t : times) tot += t;
+ if (!tot) return;
+ printf("======================= Timing: %g ms in total\n", tot);
+ for (int i = 0; i < int(times.size()); ++i) {
+ if (times[i]) {
+ printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
+ }
+ }
+ }
+ }
+ static Perf& instance() {
+ static Perf p(true);
+ return p;
+ }
+ static double delta(const TimePoint& t1, const TimePoint& t2) {
+ return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
+ }
+};
+#endif
+
+#ifdef __AVX2__
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+#endif
+
+namespace {
+
+typedef struct {
+ int32_t i1;
+ int32_t i2;
+} mmid_row_mapping;
+
+struct DataInfo {
+ float * s;
+ const char * cy;
+ size_t bs;
+ size_t by;
+ int cur_y = 0;
+ int ne11;
+ const mmid_row_mapping * row_mapping = nullptr;
+ size_t bs2 = 0;
+
+ inline const char * src1_row(int iy) const {
+ if (!row_mapping) return cy + (cur_y + iy)*by;
+ int i11 = row_mapping[cur_y + iy].i1 % ne11;
+ int i12 = row_mapping[cur_y + iy].i2;
+ return cy + (i11 + i12*ne11)*by;
+ }
+
+ inline void store(int ix, int iy, float result) const {
+ *(dst_row(iy) + ix) = result;
+ }
+#ifdef __AVX__
+ inline void store(int ix, int iy, __m128 result) const {
+ _mm_storeu_ps(dst_row(iy) + ix, result);
+ }
+ inline void store(int ix, int iy, __m256 result) const {
+ _mm256_storeu_ps(dst_row(iy) + ix, result);
+ }
+#endif
+#ifdef __AVX512F__
+ inline void store(int ix, int iy, __m512 result) const {
+ _mm512_storeu_ps(dst_row(iy) + ix, result);
+ }
+#endif
+#ifdef __ARM_NEON
+ inline void store(int ix, int iy, float32x4_t result) const {
+ vst1q_f32(dst_row(iy) + ix, result);
+ }
+#endif
+ inline float * dst_row(int iy) const {
+ if (!row_mapping) return s + (cur_y + iy)*bs;
+ int i12 = row_mapping[cur_y + iy].i2;
+ int i1 = row_mapping[cur_y + iy].i1;
+ int i2 = i12;
+ return s + i1*bs + i2*bs2;
+ }
+};
+
+typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
+
+#endif
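
A minimal sketch (ours) of what DataInfo::dst_row does when a row mapping is present, as used for indirect mul_mat_id-style multiplications; the stand-in type and numbers are illustrative, not from the commit.

    #include <cstdio>

    struct row_map { int i1, i2; };  // stand-in for mmid_row_mapping

    int main() {
        // 4 rows of 8 floats; bs is the row stride, bs2 the slice stride.
        float dst[32] = {};
        const size_t bs = 8, bs2 = 8;
        const row_map map[2] = {{1, 0}, {0, 3}};
        for (int iy = 0; iy < 2; ++iy) {
            // With a mapping, row iy is redirected: i2 picks the output slice,
            // i1 the row within it (without a mapping it would be dst + iy*bs).
            float * row = dst + map[iy].i1*bs + map[iy].i2*bs2;
            row[0] = 1.0f + iy;
        }
        printf("%g %g\n", dst[8], dst[24]);  // prints: 1 2
        return 0;
    }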
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
new file mode 100644
index 00000000..fecd818b
--- /dev/null
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -0,0 +1,194 @@
+#include "iqk_config.h"
+#include "iqk_mul_mat.h"
+#include "iqk_flash_impl.h"
+
+#ifdef IQK_IMPLEMENT
+
+#include <algorithm>
+#include <cstdio>
+#include <vector>
+#include <cstdint>
+#include <cstring>
+#include <cmath>
+
+namespace {
+inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
+ while (a != b) {
+ if (a > b) a -= b;
+ else b -= a;
+ }
+ return a;
+}
+}
+
+// TODO: get the ggml_type enum here without pollution
+//
+bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
+ int neq3, int neq2, long nbq3, long nbq2,
+ int nek3, int nek2, long nbk3, long nbk2,
+ int nev3, int nev2, long nbv3, long nbv2,
+ int ne2, int ne1, long nb1,
+ int int_type_k, // type of k
+ int int_type_v, // type of v
+ int Dk, // K head size
+ int Dv, // V head size
+ int neq1, // number of columns in q
+ int nek1, // number of rows in k
+ int stride_q, // distance between q columns in bytes
+ int stride_k, // distance between k rows in bytes
+ int stride_v, // distance between v rows in bytes
+ int stride_m, // distance between mask rows (in bytes)
+ const void * q, // q matrix.
+ const void * k, // k matrix. Assumed to be fp16, nq x nk elements
+ const void * v, // v matrix. Assumed to be fp16, nq x nk elements
+ const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
+ float scale, // scale applied before softmax
+ float softcap, // if > 0, a "soft-cap" operation is applied before softmax
+ float * qkv, // v*softmax(scale*(k*q))
+ [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
+ int ith, int nth) {
+
+ if (type_q != 0 || type_mask != 1 || max_bias > 0) return false;
+
+ int rk2 = neq2/nek2;
+ int rv2 = neq2/nev2;
+ int rk3 = neq3/nek3;
+ int rv3 = neq3/nev3;
+
+ // Getting confused all the time about where to load data from and store the results to
+ // (especially when combining the results from the threads).
+ // So, for now, making it work just for MLA (nek2 = 1).
+ // I think it would also speed up things for GQA, but I'm leaving this for another day.
+ if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) {
+ int nstep_k = nek1/32;
+ int gcd_k = simple_gcd(nstep_k, nth);
+ if (gcd_k >= 1) {
+ int nth_k = nth/gcd_k;
+ if (rk2%nth_k == 0) {
+ int ith_k = ith%gcd_k;
+ int ith_q = ith/gcd_k;
+ auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k;
+ auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
+ auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2;
+ auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
+ auto work = (char *)work_buffer;
+
+ // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float)
+ // In addition, we need M, S for the rk2/nth_k rows the thread is processing
+ // => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not
+ // writing onto the same cache line.
+ auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float);
+ auto result_buffer = work;
+ auto work_this_thread = (float *)(result_buffer + ith*size_thread);
+ if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+ Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
+ (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
+ scale, softcap,
+ work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false;
+
+ barrier(barrier_data);
+
+ // TODO: simdify this
+ for (int j = ith; j < rk2; j += nth) {
+ auto Racc = qkv + j*nb1/sizeof(float);
+ float M = -INFINITY, S = 0;
+ int jth_q = j/(rk2/nth_k);
+ int jj = j%(rk2/nth_k);
+ for (int j1 = 0; j1 < rk2/nth_k; ++j1) {
+ auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread);
+ auto Mj = R + Dv*rk2/nth_k;
+ auto Sj = Mj + rk2/nth_k;
+ R += jj*Dv;
+ if (Mj[jj] == -INFINITY) continue;
+ if (Mj[jj] > M) {
+ if (M == -INFINITY) {
+ std::memcpy(Racc, R, Dv*sizeof(float));
+ S = Sj[jj];
+ } else {
+ float c = exp(M - Mj[jj]);
+ S = c*S + Sj[jj];
+ for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
+ }
+ M = Mj[jj];
+ } else {
+ float c = exp(Mj[jj] - M);
+ S += c*Sj[jj];
+ for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
+ }
+ }
+ float norm = S > 0 ? 1/S : 1;
+ for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
+ }
+ return true;
+
+ }
+ }
+ }
+
+ // I keep changing my mind what is the best strategy to split the threads when processing
+ // multiple heads. This is my current thinking, the commented out code below was the previous.
+ int ntg = nth/simple_gcd(neq2*neq3, nth);
+ int neq1g = (neq1 + ntg - 1)/ntg;
+ //int64_t work_per_slice = D*nek1*neq1;
+ //int ntg = 1;
+ //
+ // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
+ // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
+ // the number of threads processing the (iq2, iq3) matrix.
+ //
+ //if (neq1 >= 8*nth) {
+ // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+ // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+ // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+ //}
+ int counter = 0;
+ for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
+ for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
+ if (counter++ % (nth/ntg) == ith/ntg) {
+ int iq1 = (ith%ntg)*neq1g;
+ int this_neq1 = std::min(neq1g, neq1-iq1);
+ if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+ Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
+ (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
+ (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
+ (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
+ (const void *)((const char *)mask + iq1*stride_m),
+ scale, softcap,
+ (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+#else
+
+bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int type_mask, [[maybe_unused]] float max_bias,
+ [[maybe_unused]] int neq3, [[maybe_unused]] int neq2, [[maybe_unused]] long nbq3, [[maybe_unused]] long nbq2,
+ [[maybe_unused]] int nek3, [[maybe_unused]] int nek2, [[maybe_unused]] long nbk3, [[maybe_unused]] long nbk2,
+ [[maybe_unused]] int nev3, [[maybe_unused]] int nev2, [[maybe_unused]] long nbv3, [[maybe_unused]] long nbv2,
+ [[maybe_unused]] int ne2, [[maybe_unused]] int ne1, [[maybe_unused]] long nb1,
+ [[maybe_unused]] int int_type_k, // type of k
+ [[maybe_unused]] int int_type_v, // type of v
+ [[maybe_unused]] int Dk, // K head size
+ [[maybe_unused]] int Dv, // V head size
+ [[maybe_unused]] int nq, // number of columns in q
+ [[maybe_unused]] int nk, // number of rows in k
+ [[maybe_unused]] int stride_q, // distance between q columns in bytes
+ [[maybe_unused]] int stride_k, // distance between k rows in bytes
+ [[maybe_unused]] int stride_v, // distance between v rows in bytes
+ [[maybe_unused]] int stride_m, // distance between mask rows (in bytes)
+ [[maybe_unused]] const void * q, // q matrix.
+ [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
+ [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
+ [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
+ [[maybe_unused]] float scale, // scale applied before softmax
+ [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
+ [[maybe_unused]] float * qkv, // v*softmax(scale*(k*q))
+ [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data,
+ [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
+ return false;
+}
+
+#endif
+
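For reference, a self-contained scalar sketch (ours, simplified from the reduction loop above) of how the per-thread partial results are folded together: each thread contributes an unnormalized accumulator R of Dv floats for its slice of the keys, plus the running maximum M and softmax sum S for each of its rows.

    #include <cmath>
    #include <cstring>
    #include <vector>

    struct Partial { std::vector<float> R; float M, S; };  // R is unnormalized

    // Merge partial attention results over disjoint key slices into out[0..Dv).
    // If every slice is fully masked (all M == -inf), out is left untouched,
    // matching the behavior of the loop above.
    void merge_partials(const std::vector<Partial>& parts, int Dv, float* out) {
        float M = -INFINITY, S = 0.0f;
        for (const auto& p : parts) {
            if (p.M == -INFINITY) continue;            // fully masked slice
            if (p.M > M) {
                if (M == -INFINITY) {                  // first contributing slice
                    std::memcpy(out, p.R.data(), Dv*sizeof(float));
                    S = p.S;
                } else {
                    float c = std::exp(M - p.M);       // rescale old accumulator
                    S = c*S + p.S;
                    for (int i = 0; i < Dv; ++i) out[i] = c*out[i] + p.R[i];
                }
                M = p.M;
            } else {
                float c = std::exp(p.M - M);           // rescale new contribution
                S += c*p.S;
                for (int i = 0; i < Dv; ++i) out[i] += c*p.R[i];
            }
        }
        float norm = S > 0 ? 1.0f/S : 1.0f;            // final softmax normalization
        for (int i = 0; i < Dv; ++i) out[i] *= norm;
    }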
diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h
new file mode 100644
index 00000000..bc68e0d8
--- /dev/null
+++ b/ggml/src/iqk/iqk_flash_impl.h
@@ -0,0 +1,23 @@
+#pragma once
+
+bool iqk_flash_attn_impl(int type_k, // type of k
+ int type_v, // type of v
+ int Dk, // K head size
+ int Dv, // V head size
+ int nq, // number of columns in q
+ int nk, // number of rows in k
+ int stride_q, // distance between q columns in bytes
+ int stride_k, // distance between k rows in bytes
+ int stride_v, // distance between v rows in bytes
+ int stride_m, // distance between mask rows (in bytes)
+ int stride_qkv, // distance between qkv rows (in bytes)
+ const float * q, // q matrix.
+ const void * k, // k matrix. Assumed to be fp16, nq x nk elements
+ const void * v, // v matrix. Assumed to be fp16, nq x nk elements
+ const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
+ float scale, // scale applied before softmax
+ float softcap, // if > 0, a "soft-cap" operation is applied before softmax
+ float * qkv, // v*softmax(scale*(k*q))
+ float * M, // optional: per-row running max; if M and S are non-null, qkv is stored unnormalized
+ float * S); // optional: per-row softmax sum, so partial results can be merged across threads
+
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 1f18837c..14cc64db 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -15870,23 +15870,23 @@ struct FlashMS {
inline void update_M_S(int j, float32x4_t * vk) {
float smax = load_and_scale(j, vk);
update_M(j, smax);
- update_S(j, vk);
+ if (M[j] > -INFINITY) update_S(j, vk);
}
inline void update_M_S(int j, float32x4_t * vk, const char * mask) {
float smax = load_apply_mask_and_scale(j, vk, mask);
update_M(j, smax);
- update_S(j, vk);
+ if (M[j] > -INFINITY) update_S(j, vk);
}
#else
inline void update_M_S(int j, F16::Data * vk) {
float smax = load_and_scale(j, vk);
update_M(j, smax);
- update_S(j, vk);
+ if (M[j] > -INFINITY) update_S(j, vk);
}
inline void update_M_S(int j, F16::Data * vk, const char * mask) {
float smax = load_apply_mask_and_scale(j, vk, mask);
update_M(j, smax);
- update_S(j, vk);
+ if (M[j] > -INFINITY) update_S(j, vk);
}
#endif
@@ -16037,7 +16037,7 @@ struct FlashQKV {
}
}
- inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
+ inline void normalize_and_store_1row(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
GGML_ASSERT(fms.S[j] > 0);
auto norm = F16::set1(1/fms.S[j]);
//auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
@@ -16047,21 +16047,43 @@ struct FlashQKV {
}
}
- inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int nq1, int stride_qkv, float * qkv) const {
- auto R = qkv_cache;
- for (int j = 0; j < nq1; ++j) {
- normalize_and_store(fms, j, R, qkv);
- qkv += stride_qkv;
- R += D;
+ inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
+ if (M && S) {
+ std::memcpy(M, fms.M, nq1*sizeof(float));
+ std::memcpy(S, fms.S, nq1*sizeof(float));
+ auto R = qkv_cache;
+ for (int j = 0; j < nq1; ++j) {
+ std::memcpy(qkv, R, D*sizeof(float));
+ qkv += stride_qkv;
+ R += D;
+ }
+ } else {
+ auto R = qkv_cache;
+ for (int j = 0; j < nq1; ++j) {
+ normalize_and_store_1row(fms, j, R, qkv);
+ qkv += stride_qkv;
+ R += D;
+ }
}
}
- inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int stride_qkv, float * qkv) const {
- auto R = qkv_cache;
- for (int j = 0; j < q_step; ++j) {
- normalize_and_store(fms, j, R, qkv);
- qkv += stride_qkv;
- R += D;
+ inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int stride_qkv, float * qkv, float * M, float * S) const {
+ if (M && S) {
+ std::memcpy(M, fms.M, q_step*sizeof(float));
+ std::memcpy(S, fms.S, q_step*sizeof(float));
+ auto R = qkv_cache;
+ for (int j = 0; j < q_step; ++j) {
+ std::memcpy(qkv, R, D*sizeof(float));
+ qkv += stride_qkv;
+ R += D;
+ }
+ } else {
+ auto R = qkv_cache;
+ for (int j = 0; j < q_step; ++j) {
+ normalize_and_store_1row(fms, j, R, qkv);
+ qkv += stride_qkv;
+ R += D;
+ }
}
}
@@ -16435,7 +16457,8 @@ template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHe
void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
- const float * q, const char * mask, float * qkv) {
+ const float * q, const char * mask, float * qkv,
+ float * M, float * S) {
#ifdef __aarch64__
float16_t q_f16[Dk*q_step];
#endif
@@ -16459,11 +16482,12 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
vh.next_block();
mr += k_step*sizeof(ggml_half);
}
- fqkv.normalize_and_store(fms, stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
q += q_step*stride_q;
mask += q_step*stride_m;
qkv += q_step*stride_qkv;
+ if (M && S) { M += q_step; S += q_step; }
}
int n_left = nq1 - q_step*(nq1/q_step);
if (n_left > 0) {
@@ -16485,7 +16509,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
vh.next_block();
mr += k_step*sizeof(ggml_half);
}
- fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
}
}
@@ -16493,7 +16517,8 @@ template <int Dk, int Dv, int q_step, int k_step, typename KHelper, typename VHe
void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
- const float * q, const char * mask, float * qkv) {
+ const float * q, const char * mask, float * qkv,
+ float * M, float * S) {
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
#if FA_TIMING
Perf perf(false);
@@ -16528,15 +16553,16 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
}
#if FA_TIMING
t1 = Perf::cur_time();
- fqkv.normalize_and_store(fms, stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
perf.accum_nolock(3, t1);
#else
- fqkv.normalize_and_store(fms, stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
#endif
q += q_step*stride_q;
mask += q_step*stride_m;
qkv += q_step*stride_qkv;
+ if (M && S) { M += q_step; S += q_step; }
}
int n_left = nq1 - q_step*(nq1/q_step);
if (n_left > 0) {
@@ -16552,7 +16578,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
vh.next_block();
mr += k_step*sizeof(ggml_half);
}
- fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
}
#if FA_TIMING
Perf::instance().add(perf);
@@ -16580,12 +16606,12 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- const float * q, const char * mask, float * qkv) {
+ const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
if (nq1 >= 8) {
@@ -16597,10 +16623,10 @@ struct FlashAttn {
HelperQ80R8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
} else{
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
}
}
else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
@@ -16613,14 +16639,14 @@ struct FlashAttn {
HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
} else{
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
}
} else {
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
- kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
}
}
@@ -16987,7 +17013,7 @@ struct FlashAttnBF16 {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- const float * q, const char * mask, float * qkv) {
+ const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) {
ggml_bf16_t q_bf16[q_step*Dk];
#if FA_TIMING
Perf perf(false);
@@ -17023,7 +17049,7 @@ struct FlashAttnBF16 {
#if FA_TIMING
t1 = Perf::cur_time();
#endif
- fqkv.normalize_and_store(fms, stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
#if FA_TIMING
perf.accum_nolock(4, t1);
#endif
@@ -17046,7 +17072,7 @@ struct FlashAttnBF16 {
vh.next_block();
mr += k_step*sizeof(ggml_half);
}
- fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv);
+ fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
}
#if FA_TIMING
Perf::instance().add(perf);
@@ -17060,32 +17086,32 @@ struct FlashAttnBF16 {
template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- const float * q, const char * mask, float scale, float softcap, float * qkv) {
+ const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
if (nk1 >= 256) { //4096) {
if (nq1 >= 64) {
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
if (nq1 >= 32) {
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
if (nq1 >= 16) {
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
}
if (nq1 >= 8) {
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
else {
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
}
@@ -17093,27 +17119,27 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
template <int Dk, int Dv, int k_step>
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
- float scale, float softcap, float * qkv) {
+ float scale, float softcap, float * qkv, float * M, float * S) {
HelperBF16<Dk, k_step> kh(k, stride_k);
HelperBF16<Dv, k_step> vh(v, stride_v);
if (nk1 >= 4096) {
if (nq1 >= 64) {
FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
else if (nq1 >= 16) {
FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
}
if (nq1 >= 8) {
FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} else {
FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
}
#endif
@@ -17122,43 +17148,43 @@ template <int Dk, int Dv, int k_step, typename KHelper>
inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * v, const char * mask,
- float scale, float softcap, float * qkv) {
+ float scale, float softcap, float * qkv, float * M, float * S) {
switch (type_v) {
case GGML_TYPE_F16: {
HelperF16<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
#ifdef __AVX512BF16__
case GGML_TYPE_BF16: {
HelperBF16<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
#endif
case GGML_TYPE_Q8_0: {
HelperQ80<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_Q4_1: {
HelperQ41<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4nl<Dv, k_step> vh(v, stride_v);
- iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
#endif
default: break;
@@ -17169,37 +17195,37 @@ template <int Dk, int Dv, int k_step>
inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
- float scale, float softcap, float * qkv) {
+ float scale, float softcap, float * qkv, float * M, float * S) {
switch (type_k) {
case GGML_TYPE_F16: {
HelperF16<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_Q8_0: {
HelperQ80<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_Q4_1: {
HelperQ41<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4nl<Dk, k_step> kh(k, stride_k);
- iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
#endif
default: break;
@@ -17223,13 +17249,13 @@ inline bool flash_attn_is_supported(ggml_type type) {
template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
- const float * q, const char * mask, float scale, float softcap, float * qkv) {
+ const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
if (nq1 % 8 == 0) {
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
} else {
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
}
@@ -17237,29 +17263,29 @@ template <int step_k>
inline bool iqk_deepseek_helper(ggml_type type_k,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
- float scale, float softcap, float * qkv) {
+ float scale, float softcap, float * qkv, float * M, float * S) {
if (type_k == GGML_TYPE_Q8_0) {
HelperQ80<576, step_k> kh((const char *)k, stride_k);
HelperQ80<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
if (type_k == GGML_TYPE_Q6_0) {
HelperQ60<576, step_k> kh((const char *)k, stride_k);
HelperQ60<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
if (type_k == GGML_TYPE_Q8_KV) {
HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
if (type_k == GGML_TYPE_F16) {
HelperF16<576, step_k> kh((const char *)k, stride_k);
HelperF16<512, step_k> vh((const char *)v, stride_v);
- iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
#ifdef __AVX512BF16__
@@ -17268,10 +17294,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
HelperBF16<512, step_k> vh((const char *)v, stride_v);
if (nq1 % 8 == 0) {
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} else {
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
return true;
}
@@ -17281,24 +17307,27 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
}
-bool iqk_flash_attn_noalibi(int int_type_k, // type of k
- int int_type_v, // type of v
- int Dk, // K head size
- int Dv, // V head size
- int nq1, // number of columns in q
- int nk1, // number of rows in k
- int stride_q, // distance between q columns in bytes
- int stride_k, // distance between k rows in bytes
- int stride_v, // distance between v rows in bytes
- int stride_m, // distance between mask rows (in bytes
- int stride_qkv, // distance between rows in mask (in bytes)
- const float * q, // q matrix.
- const void * k, // k matrix. Assumed to be fp16, nq x nk elements
- const void * v, // v matrix. Assumed to be fp16, nq x nk elements
- const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
- float scale, // scale applied before softmax
- float softcap, // if > 0, a "soft-cap" operation is applied before softmax
- float * qkv) { // v*softmax(scale*(k*q))
+#include "iqk_flash_impl.h"
+
+bool iqk_flash_attn_impl(int int_type_k, // type of k
+ int int_type_v, // type of v
+ int Dk, // K head size
+ int Dv, // V head size
+ int nq1, // number of columns in q
+ int nk1, // number of rows in k
+ int stride_q, // distance between q columns in bytes
+ int stride_k, // distance between k rows in bytes
+ int stride_v, // distance between v rows in bytes
+ int stride_m, // distance between mask rows (in bytes)
+ int stride_qkv, // distance between qkv rows (in bytes)
+ const float * q, // q matrix.
+ const void * k, // k matrix. Assumed to be fp16, nq x nk elements
+ const void * v, // v matrix. Assumed to be fp16, nq x nk elements
+ const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
+ float scale, // scale applied before softmax
+ float softcap, // if > 0, a "soft-cap" operation is applied before softmax
+ float * qkv, // v*softmax(scale*(k*q))
+ float * M, float * S) {
if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
@@ -17309,13 +17338,13 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
GGML_ASSERT(type_k == type_v);
stride_q /= sizeof(float); // q stride as float
return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
- q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv);
+ q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
}
if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
if (Dk != Dv && Dk != 192 && Dv != 128) return false;
if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
- if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false;
+ if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dk != 256) return false;
auto ck = (const char *)k;
auto cv = (const char *)v;
@@ -17329,15 +17358,15 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
switch (Dk) {
case 64:
- iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 96:
- iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 128:
- iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 192:
- iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 256:
- iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
default:
return false;
}
@@ -17346,15 +17375,15 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
switch (Dk) {
case 64:
- iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 96:
- iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 128:
- iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 192:
- iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 256:
- iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
default:
return false;
}
@@ -17366,21 +17395,21 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
if (nk1%64 == 0) {
switch (Dk) {
case 64:
- iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 80:
// iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 112:
// iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 192:
- iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 256:
- iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
default:
return false;
}
@@ -17388,21 +17417,21 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
}
switch (Dk) {
case 64:
- iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 80:
// iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 112:
// iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 192:
- iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
case 256:
- iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break;
default:
return false;
}
@@ -17437,25 +17466,4 @@ bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/
return false;
}
-
-bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k
- [[maybe_unused]] int int_type_v, // type of v
- [[maybe_unused]] int D, // head size
- [[maybe_unused]] int nq, // number of columns in q
- [[maybe_unused]] int nk, // number of rows in k
- [[maybe_unused]] int stride_q, // distance between q columns in bytes
- [[maybe_unused]] int stride_k, // distance between k rows in bytes
- [[maybe_unused]] int stride_v, // distance between v rows in bytes
- [[maybe_unused]] int stride_m, // distance between mask rows (in bytes
- [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes)
- [[maybe_unused]] const float * q, // q matrix.
- [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements
- [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements
- [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
- [[maybe_unused]] float scale, // scale applied before softmax
- [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax
- [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q))
- return false;
-}
-
#endif
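
The `if (M[j] > -INFINITY)` guards added to update_M_S above handle rows whose keys are all masked so far: the running maximum is still -inf, and evaluating exp(s - M) would produce exp(-inf - (-inf)) = NaN. A scalar sketch (ours) of the per-row update, omitting the SIMD and the V-accumulator rescale that the real code also performs:

    #include <algorithm>
    #include <cmath>

    // One k_step block of the online softmax for a single query row:
    // s holds the scaled (and masked) logits, M/S the running max and sum,
    // w receives the softmax weights used to accumulate V.
    void update_M_S(const float* s, int k_step, float& M, float& S, float* w) {
        float smax = -INFINITY;
        for (int i = 0; i < k_step; ++i) smax = std::max(smax, s[i]);
        if (smax > M) {                                  // update_M: new running max
            if (M > -INFINITY) S *= std::exp(M - smax);  // rescale old sum
            M = smax;      // (the V accumulator gets rescaled the same way)
        }
        if (M > -INFINITY) {                             // the guard from this commit
            for (int i = 0; i < k_step; ++i) {           // update_S
                w[i] = std::exp(s[i] - M);
                S += w[i];
            }
        }
    }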
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index 767f89cf..d91c4710 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -33,7 +33,14 @@ bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
-bool iqk_flash_attn_noalibi(int type_k, // type of k
+typedef void (*barrier_t) (void *);
+
+bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
+ int neq3, int neq2, long nbq3, long nbq2,
+ int nek3, int nek2, long nbk3, long nbk2,
+ int nev3, int nev2, long nbv3, long nbv2,
+ int ne2, int ne1, long nb1,
+ int type_k, // type of k
int type_v, // type of v
int Dk, // K head size
int Dv, // V head size
@@ -43,14 +50,15 @@ bool iqk_flash_attn_noalibi(int type_k, // type of k
int stride_k, // distance between k rows in bytes
int stride_v, // distance between v rows in bytes
 int stride_m, // distance between mask rows (in bytes)
- int stride_qkv, // distance between rows in mask (in bytes)
- const float * q, // q matrix.
+ const void * q, // q matrix.
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
- float * qkv); // v*softmax(scale*(k*q))
+ float * qkv, // v*softmax(scale*(k*q))
+ void * work_buffer, barrier_t barrier, void * barrier_data,
+ int ith, int nth);
#ifdef __cplusplus
}