diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-05-15 08:43:39 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-15 08:43:39 +0300 |
commit | 3f8c865b920df844ba0cb4ba53c1ccce8874b045 (patch) | |
tree | 81c7f9d40578ca66f06941ec5118ff068bb9347c | |
parent | 14ed9fb44da5212b4334277606e47c7040888a8a (diff) |
Fix standard attention on the CPU (#421)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 26 |
1 file changed, 6 insertions, 20 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 92f58d55..6c3a3575 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -461,27 +461,15 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
             int ny = mm.funcs.size();
             while (ny > 0 && !mm.funcs[ny-1]) --ny;
             if (ny >= r2) {
-                int nx64 = Nx/64;
-                int nchunk64 = nx64*ne02;
-                for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
-                    int i02 = ichunk/nx64;
-                    int ix = 64*(ichunk - i02*nx64);
+                nchunk = nx32*ne02;
+                for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
+                    int i02 = ichunk/nx32;
+                    int ix = 32*(ichunk - i02*nx32);
                     DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
-                    mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
-                }
-                int ix0 = 64*nx64;
-                if (ix0 < Nx) {
-                    nx32 -= 2*nx64;
-                    nchunk = nx32*ne02;
-                    for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
-                        int i02 = ichunk/nx32;
-                        int ix = ix0 + 32*(ichunk - i02*nx32);
-                        DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
-                        mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
-                    }
+                    mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
                 }
+                return true;
             }
-            return true;
         }
         for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
             int i02 = ichunk/nx32;
@@ -494,7 +482,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
         }
         return true;
     }
-    //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
     int gcd = simple_gcd(ne02, nth);
     int counter = 0;
     for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -510,7 +497,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
     }
 
     if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
-        //printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
         MulMat mm;
         if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
             return false;