-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 58
1 file changed, 56 insertions(+), 2 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5750b952..fffddada 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -352,6 +352,26 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
         }
         return true;
     }
+
+    if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
+        //printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
+        MulMat mm;
+        if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
+            return false;
+        }
+        int n_per_thread = (Nx + nth - 1)/nth;
+        int first = ith*n_per_thread;
+        if (first >= Nx) return true;
+        int last = first + n_per_thread <= Nx ? first + n_per_thread : Nx;
+        for (int ix = first; ix < last; ++ix) {
+            for (int i02 = 0; i02 < ne02; ++i02) {
+                DataInfo info{C + ix + i02*nb2, (const char *)B + i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+                mm.funcs[0](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), nb02, info, 1);
+            }
+        }
+        return true;
+    }
+
     int gcd = simple_gcd(ne12*ne13, nth);
     int counter = 0;
     for (int64_t i13 = 0; i13 < ne13; i13++) {
@@ -6229,8 +6249,8 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
 // The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
 template <int nrc_y>
 static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%8 == 0);
     GGML_ASSERT(n%32 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
 #ifndef HAVE_FANCY_SIMD
     auto m1 = _mm256_set1_epi16(1);
 #endif
@@ -6298,8 +6318,42 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data
 
 template <int nrc_y>
 static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%8 == 0);
     GGML_ASSERT(n%32 == 0);
+    if (nrc_y == 1 && nrc_x == 1) {
+        auto dx = (const float *)vx;
+        auto dy = (const float *)info.src1_row(0);
+#ifdef HAVE_FANCY_SIMD
+        auto sy = (const int32_t *)(dy + 1);
+        auto x = (const int8_t *)(dx + 2);
+        auto y = (const int8_t *)(dy + 2);
+        auto isum = _mm512_setzero_si512();
+        for (int i = 0; i < n/64; ++i) {
+            auto qx = _mm512_loadu_si512((const __m512i *)x + i);
+            auto qy = _mm512_loadu_si512((const __m512i *)y + i);
+            isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy);
+        }
+        auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1));
+        for (int i = 2*(n/64); i < n/32; ++i) {
+            auto qx = _mm256_loadu_si256((const __m256i *)x + i);
+            auto qy = _mm256_loadu_si256((const __m256i *)y + i);
+            isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy);
+        }
+        info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0]));
+#else
+        auto x = (const int8_t *)(dx + 2);
+        auto y = (const int8_t *)(dy + 2);
+        auto isum = _mm256_setzero_si256();
+        for (int i = 0; i < n/32; ++i) {
+            auto qx = _mm256_loadu_si256((const __m256i *)x + i);
+            auto qy = _mm256_loadu_si256((const __m256i *)y + i);
+            auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
+            isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot));
+        }
+        info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum));
+#endif
+        return;
+    }
+    GGML_ASSERT(nrc_x%8 == 0);
     __m256i qx[2];
     __m256i acc[2*nrc_y] = {};
     float dy[nrc_y];
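
Note on the first hunk: the new special case handles token-generation (TG) attention GEMMs, where there is a single row of activations per head (Ny == 1) and one matrix per head (ne12 == ne02). Instead of distributing the ne12*ne13 GEMMs across threads via the gcd scheme below it, it splits the Nx result rows into contiguous chunks of ceil(Nx/nth) and loops over all heads inside each chunk. A minimal standalone sketch of that partitioning (a hypothetical demo, not repo code):

#include <algorithm>
#include <cstdio>

int main() {
    int Nx = 10, nth = 4;                    // 10 rows split over 4 threads
    int n_per_thread = (Nx + nth - 1)/nth;   // ceil(10/4) = 3
    for (int ith = 0; ith < nth; ++ith) {
        int first = ith*n_per_thread;
        if (first >= Nx) { printf("thread %d: idle\n", ith); continue; }
        int last = std::min(first + n_per_thread, Nx);
        printf("thread %d: rows [%d, %d)\n", ith, first, last);
    }
}

This prints rows [0,3), [3,6), [6,9), [9,10) for threads 0..3; the `if (first >= Nx)` early-out in the hunk covers the case where Nx < nth and trailing threads get no work.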
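
Note on the third hunk: judging from the pointer arithmetic, a q8_KV row starts with a float scale, then a 32-bit sum of the row's quants (hence `dx`, `(const int32_t *)(dy + 1)`, and the int8 quants at `dx + 2`). The AVX512-VNNI branch uses _mm512_dpbusd_epi32, which multiplies *unsigned* bytes by signed bytes, so the x quants are biased by 127 and the result is corrected by 127*sy[0], relying on the identity dot(x, y) == dot(x + 127, y) - 127*sum(y); the second loop handles a 32-byte tail when n%64 != 0. A scalar sketch of the identity (assuming quants stay in [-127, 127], since the epi8 add wraps at -128):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    int8_t x[8] = {-127, -3, 0, 5, 42, -100, 127, 7};
    int8_t y[8] = {1, -2, 3, -4, 5, -6, 7, -8};
    int32_t dot = 0, dot_biased = 0, sum_y = 0;
    for (int i = 0; i < 8; ++i) {
        uint8_t u = (uint8_t)(x[i] + 127);   // what the epi8 add produces
        dot        += x[i]*y[i];             // true signed dot product
        dot_biased += u*y[i];                // what dpbusd accumulates
        sum_y      += y[i];                  // precomputed as sy[0] per row
    }
    assert(dot == dot_biased - 127*sum_y);
    printf("dot = %d, biased - correction = %d\n", dot, dot_biased - 127*sum_y);
}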
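
The non-VNNI fallback feeds _mm256_maddubs_epi16, which also requires an unsigned first operand. The trick there: _mm256_sign_epi8(qx, qx) yields |x|, and _mm256_sign_epi8(qy, qx) flips y by the sign of x (zeroing it where x == 0), so each lane still contributes x*y. With |x| <= 127 the pairwise int16 sums peak at 2*127*127 = 32258, below the saturation limit of 32767. A scalar sketch of the per-lane behavior (hypothetical, same [-127, 127] assumption):

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Mimics one lane of psignb: pass a, negate a, or zero, by the sign of b.
static int8_t sign_apply(int8_t a, int8_t b) {
    return b > 0 ? a : b < 0 ? (int8_t)-a : (int8_t)0;
}

int main() {
    int8_t x[4] = {-7, 0, 3, -127};
    int8_t y[4] = {5, 9, -2, 1};
    for (int i = 0; i < 4; ++i) {
        uint8_t ax = (uint8_t)abs(x[i]);      // _mm256_sign_epi8(qx, qx)
        int8_t  sy = sign_apply(y[i], x[i]);  // _mm256_sign_epi8(qy, qx)
        assert(ax*sy == x[i]*y[i]);           // unsigned*signed == x*y
    }
    printf("sign-trick products match x*y\n");
}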