author	Kawrakow <iwankawrakow@gmail.com>	2025-02-22 09:38:51 +0200
committer	GitHub <noreply@github.com>	2025-02-22 09:38:51 +0200
commit	c4a5103299e44adc8692e3e373c1974fa9fee270 (patch)
tree	f0afc8baa6af5e7805835c76c711f0bc58771f73
parent	b9a6639ac3bc77c64bba679cb85b14de0c4a9c9d (diff)
Better strategy for attention matrix multiplications when generating tokens (#218)
* This seems to be a better way to do the attention matrix multiplications in the TG case.

* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--	ggml/src/iqk/iqk_mul_mat.cpp	58
1 file changed, 56 insertions, 2 deletions
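
The new branch added in the first hunk below handles the token-generation (TG) case, where the right-hand side has a single column (Ny == 1) and a single batch (ne13 == 1) but several heads (ne12 == ne02): instead of distributing whole (head, batch) pairs across threads, it splits the Nx rows of the left operand across threads and loops over all heads for each row. A minimal standalone sketch of that partitioning, using hypothetical names and a naive float mat-vec kernel in place of the iqk kernels (illustration only, not the actual iqk_mul_mat_4d interface):

#include <algorithm>
#include <cstdio>
#include <vector>

// Naive stand-in for the quantized kernel called through mm.funcs[0].
static void mat_vec(int ne00, const float* a_row, const float* b_col, float* c) {
    float sum = 0.0f;
    for (int k = 0; k < ne00; ++k) sum += a_row[k] * b_col[k];
    *c = sum;
}

// Each thread takes a contiguous chunk of the Nx rows and, for every row in
// its chunk, multiplies against the per-head right-hand vectors.
// Layout here is illustrative: A is [ne02][Nx][ne00], B is [ne02][ne00],
// C is [ne02][Nx].
static void tg_attention_gemv(int Nx, int ne00, int ne02, int ith, int nth,
                              const std::vector<float>& A,
                              const std::vector<float>& B,
                              std::vector<float>& C) {
    int n_per_thread = (Nx + nth - 1) / nth;
    int first = ith * n_per_thread;
    if (first >= Nx) return;
    int last = std::min(first + n_per_thread, Nx);
    for (int ix = first; ix < last; ++ix) {
        for (int i02 = 0; i02 < ne02; ++i02) {
            const float* a_row = A.data() + (size_t)i02 * Nx * ne00 + (size_t)ix * ne00;
            const float* b_col = B.data() + (size_t)i02 * ne00;
            mat_vec(ne00, a_row, b_col, C.data() + (size_t)i02 * Nx + ix);
        }
    }
}

int main() {
    int Nx = 8, ne00 = 4, ne02 = 2, nth = 3;
    std::vector<float> A(ne02 * Nx * ne00, 1.0f), B(ne02 * ne00, 0.5f), C(ne02 * Nx, 0.0f);
    for (int ith = 0; ith < nth; ++ith) tg_attention_gemv(Nx, ne00, ne02, ith, nth, A, B, C);
    std::printf("C[0] = %g (expected %g)\n", C[0], 0.5f * ne00);
}

A plausible motivation: with Ny == 1 each head contributes only a matrix-vector product, so splitting by rows keeps all nth threads busy even when the number of heads is small or does not divide nth evenly.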
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5750b952..fffddada 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -352,6 +352,26 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
}
return true;
}
+
+ if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
+ //printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
+ MulMat mm;
+ if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
+ return false;
+ }
+ int n_per_thread = (Nx + nth - 1)/nth;
+ int first = ith*n_per_thread;
+ if (first >= Nx) return true;
+ int last = first + n_per_thread <= Nx ? first + n_per_thread : Nx;
+ for (int ix = first; ix < last; ++ix) {
+ for (int i02 = 0; i02 < ne02; ++i02) {
+ DataInfo info{C + ix + i02*nb2, (const char *)B + i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
+ mm.funcs[0](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), nb02, info, 1);
+ }
+ }
+ return true;
+ }
+
int gcd = simple_gcd(ne12*ne13, nth);
int counter = 0;
for (int64_t i13 = 0; i13 < ne13; i13++) {
@@ -6229,8 +6249,8 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
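
The next hunk adds a fast path for the nrc_y == 1 && nrc_x == 1 case. On the AVX512-VNNI branch it uses _mm512_dpbusd_epi32, whose left operand is treated as unsigned, so the signed int8 quants of x are shifted by +127 and the bias 127*sum(y) (taken from the precomputed row sum sy[0]) is subtracted at the end. A scalar sketch of the underlying identity, with illustrative values only:

#include <cstdint>
#include <cstdio>

// Scalar illustration of the offset trick used in the VNNI path below:
//   sum((x + 127) * y) - 127 * sum(y) == sum(x * y)
int main() {
    int8_t x[8] = {-3, 5, -127, 7, 0, 1, -1, 100};
    int8_t y[8] = {2, -4, 3, -1, 5, -6, 7, -8};
    int32_t dot = 0, dot_offset = 0, sum_y = 0;
    for (int i = 0; i < 8; ++i) {
        dot        += (int32_t)x[i] * y[i];
        dot_offset += ((int32_t)x[i] + 127) * y[i];  // unsigned-style accumulation
        sum_y      += y[i];
    }
    std::printf("direct = %d, offset - 127*sum(y) = %d\n", dot, dot_offset - 127 * sum_y);
}

The non-VNNI branch in the same hunk needs no correction term: it uses the usual sign trick, _mm256_maddubs_epi16(|x|, sign(x)*y), via the two _mm256_sign_epi8 calls.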
@@ -6298,8 +6318,42 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
+ if (nrc_y == 1 && nrc_x == 1) {
+ auto dx = (const float *)vx;
+ auto dy = (const float *)info.src1_row(0);
+#ifdef HAVE_FANCY_SIMD
+ auto sy = (const int32_t *)(dy + 1);
+ auto x = (const int8_t *)(dx + 2);
+ auto y = (const int8_t *)(dy + 2);
+ auto isum = _mm512_setzero_si512();
+ for (int i = 0; i < n/64; ++i) {
+ auto qx = _mm512_loadu_si512((const __m512i *)x + i);
+ auto qy = _mm512_loadu_si512((const __m512i *)y + i);
+ isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy);
+ }
+ auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1));
+ for (int i = 2*(n/64); i < n/32; ++i) {
+ auto qx = _mm256_loadu_si256((const __m256i *)x + i);
+ auto qy = _mm256_loadu_si256((const __m256i *)y + i);
+ isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy);
+ }
+ info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0]));
+#else
+ auto x = (const int8_t *)(dx + 2);
+ auto y = (const int8_t *)(dy + 2);
+ auto isum = _mm256_setzero_si256();
+ for (int i = 0; i < n/32; ++i) {
+ auto qx = _mm256_loadu_si256((const __m256i *)x + i);
+ auto qy = _mm256_loadu_si256((const __m256i *)y + i);
+ auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
+ isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot));
+ }
+ info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum));
+#endif
+ return;
+ }
+ GGML_ASSERT(nrc_x%8 == 0);
__m256i qx[2];
__m256i acc[2*nrc_y] = {};
float dy[nrc_y];