diff options
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 78270f5e..424a65af 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -451,6 +451,51 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, auto r3 = ne13 / ne03; if (ne13 == 1 && Ny == 1 && r2 > 1) { + if (Nx >= 256 && Nx%32 == 0) { + int nx32 = Nx/32; + int nchunk = nx32*ne02; + if (r2 <= 8) { + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false; + int nx64 = Nx/64; + int nchunk64 = nx64*ne02; + for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) { + int i02 = ichunk/nx64; + int ix = 64*(ichunk - i02*nx64); + DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64); + } + int ix0 = 64*nx64; + if (ix0 < Nx) { + nx32 -= 2*nx64; + nchunk = nx32*ne02; + for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + int i02 = ichunk/nx32; + int ix = ix0 + 32*(ichunk - i02*nx32); + DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32); + } + } + //for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + // int i02 = ichunk/nx32; + // int ix = 32*(ichunk - i02*nx32); + // DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + // mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32); + //} + return true; + } + for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + int i02 = ichunk/nx32; + int ix = ichunk - i02*nx32; + if (!iqk_mul_mat(32, r2, ne00, + typeA, (const char *)A + 32*ix*strideA + i02*nb02, strideA, + typeB, (const char *)B + i02*r2*nb12, nb12, + C + 32*ix + r2*i02*nb2, nb2, 0, 1)) return false; + + } + return true; + } + //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02); int gcd = simple_gcd(ne02, nth); int counter = 0; for (int64_t i12 = 0; i12 < ne02; i12++) { @@ -17153,6 +17198,14 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } + else if (nq1 >= 4) { + FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + } + else if (nq1 >= 2) { + FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + } else { FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); |