summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-quants.c
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-01-15 18:19:22 +0200
committerGitHub <noreply@github.com>2025-01-15 18:19:22 +0200
commit0b74397d596bbcdfba27299393406d2b6330b133 (patch)
tree2101d059f79b6b268086c71878aa2da1c328c73d /ggml/src/ggml-quants.c
parent49b27069fd267d3dac8de5d13141b4274e4be16b (diff)
CPU Flash Attention improvements (#172)
* Slightly faster FA for bf16 KV cache ~2-3% sort of thing. Sadly, when we go beyond 8k tokens, the advantage kind of goes away. * Slightly faster FA for Q8_0 KV cache * FA: allow bf16 for V-cache with any supported K-cache E.g., -ctk q8_0 -ctv bf16 is slightly faster than -ctk q8_0 -ctv q8_0 on Zen4 for not too long context lengths (say, <= 4096). * FA: much better bf16 kv-cache speed for large contexts We now hit 122 t/s for LLaMA-3.1-8B (quantized as iq4_xs and run-time-repacked) with a context of 32768. IIRC, the previous best for such large context was ~90 t/s. Non-negligible improvement at 16384 and 8192 as well: 173.4 and 214 t/s. * FA: slightly better quantized kv-cache speed for large contexts E.g., for q8_0 and context of 32768, we are now at 113 t/s for LLaMA-3.1-8B. Also simplified the quantized K*Q multiplication. * Fix q8_0 KV cache when not using FA - WIP (AVX2) 1. We add new types GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4, and use those to quantize activations for quants that use Q8_0 or Q8_1 as their vec_dot type. 2. We revert the changes to quantize_row_q8_0 and quantize_row_q8_1 3. We use GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4 as the vec_dot type 4. We change the FA implementation to use GGML_TYPE_Q8_0 rather than GGML_TYPE_Q8_0_X4 as the K and V types 5. We change the expected type to GGML_TYPE_Q8_0_X4/GGML_TYPE_Q8_1_X4 in iqk_mul_mat Also added an optimization in ggml_compute_forward_mul_mat when ne12*ne13 > 1 (K*Q and V*softmax(K*Q)) to process n12*ne13/GCD(n12*ne13, nthread) threads simultaneously using nthread/GCD(n12*ne13, nthread) threads per head. This results in a non-negligible performance gain for large contexts. Question: why is it not allowed to use quantized V-cache when not using FA? * Fix q8_0 KV cache when not using FA - NEON * Fix AVX2 Again the issue with _mm256_maddubs_epi16 overflowing that I keep forgetting. * FA: don't use large Q steps on AVX2 for fp16 K-cache * On Zen4 it is also better to not use large Q steps for fp16 K-cache --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml-quants.c')
-rw-r--r--ggml/src/ggml-quants.c105
1 files changed, 16 insertions, 89 deletions
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index d460b84a..23ac9915 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -934,13 +934,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
block_q8_0 * restrict y = vy;
-#if GGML_USE_IQK_MULMAT
- const int nb4 = 4*(nb/4);
-#else
- const int nb4 = -1;
-#endif
#if defined(__ARM_NEON)
- block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
@@ -959,27 +953,16 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
- if (i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
+ y[i].d = GGML_FP32_TO_FP16(d);
for (int j = 0; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
- if (i < nb4) {
- y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
- y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
- y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
- y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
- } else {
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
- }
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__wasm_simd128__)
@@ -1016,14 +999,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
}
}
#elif defined(__AVX2__) || defined(__AVX__)
- block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
-#ifdef __AVX2__
- const bool pack = true;
-#else
- const bool pack = false;
-#endif
for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
@@ -1045,11 +1021,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
// Quantize these floats
const float d = maxScalar / 127.f;
- if (pack && i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
+ y[i].d = GGML_FP32_TO_FP16(d);
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
@@ -1084,11 +1056,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
- if (i < nb4) {
- _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
- } else {
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
- }
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE
@@ -1287,15 +1255,8 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
block_q8_1 * restrict y = vy;
-#if GGML_USE_IQK_MULMAT
- const int nb4 = 4*(nb/4);
-#else
- const int nb4 = -1;
-#endif
#if defined(__ARM_NEON)
- block_q8_1_x4 * restrict y4 = vy;
for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
@@ -1312,11 +1273,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
- if (i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
+ y[i].d = GGML_FP32_TO_FP16(d);
int32x4_t accv = vdupq_n_s32(0);
@@ -1324,26 +1281,15 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
- if (i < nb4) {
- y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
- y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
- y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
- y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
- } else {
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
- }
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
accv = vaddq_s32(accv, vi);
}
- if (i < nb4) {
- y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
- } else {
- y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
- }
+ y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
}
#elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) {
@@ -1389,14 +1335,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
wasm_i32x4_extract_lane(accv, 3)));
}
#elif defined(__AVX2__) || defined(__AVX__)
- block_q8_1_x4 * restrict y4 = vy;
-#ifdef __AVX2__
- const bool pack = true;
-#else
- const bool pack = false;
-#endif
for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
@@ -1418,11 +1357,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
// Quantize these floats
const float d = max_scalar / 127.f;
- if (pack && i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
+ y[i].d = GGML_FP32_TO_FP16(d);
const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
@@ -1446,11 +1381,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
- if (i < nb4) {
- y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
- } else {
- y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
- }
+ y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@@ -1464,11 +1395,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
- if (i < nb4) {
- _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
- } else {
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
- }
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE