summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-05-15 08:43:39 +0300
committerGitHub <noreply@github.com>2025-05-15 08:43:39 +0300
commit3f8c865b920df844ba0cb4ba53c1ccce8874b045 (patch)
tree81c7f9d40578ca66f06941ec5118ff068bb9347c
parent14ed9fb44da5212b4334277606e47c7040888a8a (diff)
Fix standard attention on the CPU (#421)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp26
1 files changed, 6 insertions, 20 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 92f58d55..6c3a3575 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -461,27 +461,15 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
int ny = mm.funcs.size();
while (ny > 0 && !mm.funcs[ny-1]) --ny;
if (ny >= r2) {
- int nx64 = Nx/64;
- int nchunk64 = nx64*ne02;
- for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
- int i02 = ichunk/nx64;
- int ix = 64*(ichunk - i02*nx64);
+ nchunk = nx32*ne02;
+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+ int i02 = ichunk/nx32;
+ int ix = 32*(ichunk - i02*nx32);
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
- mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
- }
- int ix0 = 64*nx64;
- if (ix0 < Nx) {
- nx32 -= 2*nx64;
- nchunk = nx32*ne02;
- for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
- int i02 = ichunk/nx32;
- int ix = ix0 + 32*(ichunk - i02*nx32);
- DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
- mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
- }
+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
}
+ return true;
}
- return true;
}
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
int i02 = ichunk/nx32;
@@ -494,7 +482,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
}
return true;
}
- //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
int gcd = simple_gcd(ne02, nth);
int counter = 0;
for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -510,7 +497,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
}
if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
- //printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
MulMat mm;
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
return false;