summaryrefslogtreecommitdiff
path: root/ggml/src/iqk/iqk_mul_mat.cpp
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-04-17 08:08:40 +0200
committerGitHub <noreply@github.com>2025-04-17 08:08:40 +0200
commit3bb64d9330d5336d76b036535474d8a4b273373c (patch)
tree3724f7c8abc20b467b756f8a498be7c619831a68 /ggml/src/iqk/iqk_mul_mat.cpp
parentf7c5a94e756e4add4d531d295ae23493d9857508 (diff)
Better TG performance for GQA models (CPU) (#332)
* Slightly better CPU TG performance for GQA * Better CPU FA implementation for TG when GQA * Minor --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp53
1 files changed, 53 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 78270f5e..424a65af 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -451,6 +451,51 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
auto r3 = ne13 / ne03;
if (ne13 == 1 && Ny == 1 && r2 > 1) {
+ if (Nx >= 256 && Nx%32 == 0) {
+ int nx32 = Nx/32;
+ int nchunk = nx32*ne02;
+ if (r2 <= 8) {
+ MulMat mm;
+ if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
+ int nx64 = Nx/64;
+ int nchunk64 = nx64*ne02;
+ for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
+ int i02 = ichunk/nx64;
+ int ix = 64*(ichunk - i02*nx64);
+ DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
+ }
+ int ix0 = 64*nx64;
+ if (ix0 < Nx) {
+ nx32 -= 2*nx64;
+ nchunk = nx32*ne02;
+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+ int i02 = ichunk/nx32;
+ int ix = ix0 + 32*(ichunk - i02*nx32);
+ DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
+ }
+ }
+ //for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+ // int i02 = ichunk/nx32;
+ // int ix = 32*(ichunk - i02*nx32);
+ // DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+ // mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
+ //}
+ return true;
+ }
+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+ int i02 = ichunk/nx32;
+ int ix = ichunk - i02*nx32;
+ if (!iqk_mul_mat(32, r2, ne00,
+ typeA, (const char *)A + 32*ix*strideA + i02*nb02, strideA,
+ typeB, (const char *)B + i02*r2*nb12, nb12,
+ C + 32*ix + r2*i02*nb2, nb2, 0, 1)) return false;
+
+ }
+ return true;
+ }
+ //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
int gcd = simple_gcd(ne02, nth);
int counter = 0;
for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -17153,6 +17198,14 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
+ else if (nq1 >= 4) {
+ FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ }
+ else if (nq1 >= 2) {
+ FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+ }
else {
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);