Diffstat (limited to 'ggml/src/iqk/iqk_quantize.cpp')
-rw-r--r--   ggml/src/iqk/iqk_quantize.cpp   290
1 file changed, 287 insertions, 3 deletions
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 24b49d89..7b777a1f 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -2967,6 +2967,103 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
     }
 #endif
 }
+// TODO: merge this with the above template
+void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
+    assert(k % 32 == 0);
+    auto dptr = (float *)vy;
+    auto q8 = (int8_t *)(dptr + 2);
+#ifdef __AVX2__
+    const __m256 signBit = _mm256_set1_ps(-0.0f);
+    const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+    __m256 maxAbs = _mm256_setzero_ps();
+    for (int ib = 0; ib < k/8; ++ib) {
+        const __m256 v = _mm256_loadu_ps(x + 8*ib);
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
+    }
+    const float maxScalar = hmax_f32_8(maxAbs);
+    if (!maxScalar) {
+        dptr[0] = dptr[1] = 0;
+        std::memset(q8, 0, k*sizeof(int8_t));
+        return;
+    }
+    dptr[0] = maxScalar / 127.f;
+    auto mul = _mm256_set1_ps(1/dptr[0]);
+    auto isum = _mm256_setzero_si256();
+    for (int i = 0; i < k/32; i++) {
+        __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i +  0));
+        __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i +  8));
+        __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 16));
+        __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 24));
+        v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
+        v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
+        v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST);
+        v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST);
+        __m256i i0 = _mm256_cvtps_epi32(v0);
+        __m256i i1 = _mm256_cvtps_epi32(v1);
+        __m256i i2 = _mm256_cvtps_epi32(v2);
+        __m256i i3 = _mm256_cvtps_epi32(v3);
+        isum = _mm256_add_epi32(isum, _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        i0 = _mm256_packs_epi32( i0, i1 );
+        i2 = _mm256_packs_epi32( i2, i3 );
+        i0 = _mm256_packs_epi16( i0, i2 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+        _mm256_storeu_si256((__m256i *)q8, i0);
+        q8 += 32;
+    }
+    auto iptr = (int32_t *)(dptr + 1);
+    iptr[0] = hsum_i32_8(isum);
+#elif defined __ARM_NEON
+    int32x4_t ival[8];
+    auto vmax = vdupq_n_f32(0.f);
+    for (int j = 0; j < k; j += 4) {
+        vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j)));
+    }
+    auto smax = vmaxvq_f32(vmax);
+    if (!smax) {
+        dptr[0] = dptr[1] = 0;
+        std::memset(q8, 0, k*sizeof(int8_t));
+        return;
+    }
+    dptr[0] = smax/127;
+    auto vid = vdupq_n_f32(1/dptr[0]);
+    auto isum = vdupq_n_s32(0);
+    for (int ib = 0; ib < k/32; ++ib) {
+        auto xb = x + 32*ib;
+        for (int k = 0; k < 8; ++k) {
+            auto val = vld1q_f32(xb + 4*k);
+            ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid));
+            isum = vaddq_s32(isum, ival[k]);
+        }
+        for (int k = 0; k < 4; ++k) {
+            auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1]));
+            vst1_s8(q8, vmovn_s16(i16));
+            q8 += 8;
+        }
+    }
+    auto iptr = (int32_t *)(dptr + 1);
+    iptr[0] = vaddvq_s32(isum);
+#else
+    float amax = 0;
+    for (int j = 0; j < k; ++j) {
+        float ax = std::abs(x[j]);
+        amax = std::max(amax, ax);
+    }
+    if (!amax) {
+        dptr[0] = dptr[1] = 0;
+        std::memset(q8, 0, k*sizeof(int8_t));
+        return;
+    }
+    dptr[0] = amax/127;
+    float id = 1/dptr[0];
+    int isum = 0;
+    for (int i = 0; i < k; i++) {
+        q8[i] = nearest_int(id*x[i]);
+        isum += q8[i];
+    }
+    auto iptr = (int32_t *)(dptr + 1);
+    iptr[0] = isum;
+#endif
+}
 }
 
 void quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
@@ -3886,7 +3983,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8
 
 #ifdef HAVE_FANCY_SIMD
 static void modify_q8_0_r8(int64_t k, char * cy) {
-    auto y = (block_iq4_nl_r8 *)cy;
+    auto y = (block_q8_0_r8 *)cy;
     int nb = k/(32*8);
     for (int ib = 0; ib < nb; ++ib) {
         for (int l = 0; l < 4; ++l) {
@@ -5413,6 +5510,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
 }
 
 //
+// ========================================= q8_KV_r8
+//
+
+void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) {
+    quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
+}
+
+void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) {
+    quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
+}
+
+static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) {
+    GGML_ASSERT(nrows%8 == 0);
+    GGML_ASSERT(n_per_row%16 == 0);
+    auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+    auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
+    const int8_t * x8[8];
+#ifdef __ARM_NEON
+    int8x16x2_t m0, m1, m2, m3;
+#endif
+    for (int row = 0; row < nrows; row += 8) {
+        auto dy = (float *)cy;
+        auto qy = (int8_t *)(dy + 8);
+        for (int k = 0; k < 8; ++k) {
+            auto dx = (const float *)(cx + k*row_size_x);
+            dy[k] = dx[0];
+            x8[k] = (const int8_t *)(dx + 2);
+        }
+        for (int ib = 0; ib < n_per_row/16; ++ib) {
+#ifdef __AVX2__
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+            auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
+            auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
+            auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
+            auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
+            auto t0 = _mm256_unpacklo_epi32(m0, m1);
+            auto t1 = _mm256_unpacklo_epi32(m2, m3);
+            auto t2 = _mm256_unpackhi_epi32(m0, m1);
+            auto t3 = _mm256_unpackhi_epi32(m2, m3);
+            m0 = _mm256_unpacklo_epi64(t0, t1);
+            m1 = _mm256_unpackhi_epi64(t0, t1);
+            m2 = _mm256_unpacklo_epi64(t2, t3);
+            m3 = _mm256_unpackhi_epi64(t2, t3);
+#ifdef HAVE_FANCY_SIMD
+            if (online) {
+                m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+                m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+                m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+                m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+            }
+#endif
+            _mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0);
+            _mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1);
+            _mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2);
+            _mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3);
+#elif defined __ARM_NEON
+            m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
+            m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
+            m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
+            m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
+            auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+            auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+            m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+            m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+            m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+            m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+            row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+            row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+            m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+            m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+            m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+            m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+            vst1q_s8_x2(qy +  0 + 128*ib, m0);
+            vst1q_s8_x2(qy + 32 + 128*ib, m1);
+            vst1q_s8_x2(qy + 64 + 128*ib, m2);
+            vst1q_s8_x2(qy + 96 + 128*ib, m3);
+#else
+            // Scalar fallback: interleave the 8 rows in groups of 4 bytes,
+            // matching the layout used by dequantize_row_q8_KV_r8 below
+            for (int l = 0; l < 4; ++l) {
+                for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+                    qy[128*ib + 32*l + 4*k + i] = x8[k][16*ib + 4*l + i];
+                }
+            }
+#endif
+        }
+        cx += 8*row_size_x;
+        cy += online ? 8*row_size_x : 8*row_size_y;
+        // When run-time repacking (online = true) we must not change the row stride, so we just leave some unused space at the end of each row
+    }
+}
+#ifdef HAVE_FANCY_SIMD
+static void modify_q8_KV_r8(int64_t k, char * cy) {
+    int8_t * q8 = (int8_t *)(cy + 8*sizeof(float));
+    for (int j = 0; j < k; ++j) q8[j] += 127;
+}
+#endif
+
+size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) {
+    GGML_ASSERT(nrows%8 == 0);
+    GGML_ASSERT(n_per_row%16 == 0);
+    char * qcur = (char *)dst;
+    auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+    auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
+    std::vector<char> qtmp(8*row_size_0);
+    for (int row = 0; row < nrows; row += 8) {
+        quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix);
+        repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false);
+        qcur += 8*row_size_1;
+        src += 8*n_per_row;
+    }
+    return nrows*row_size_1;
+}
+
+void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) {
+    auto n_per_row = k/8;
+    float * y8[8];
+    for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k;
+    auto dptr = (const float *)vx;
+    auto q8 = (const int8_t *)(dptr + 8);
+    for (int ib = 0; ib < n_per_row/16; ++ib) {
+        for (int k = 0; k < 8; ++k) {
+            for (int l = 0; l < 4; ++l) {
+                for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i];
+            }
+        }
+    }
+}
+
+void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
+        return;
+    }
+#endif
+    GGML_ASSERT(n%QK4_NL == 0);
+    GGML_ASSERT(nrc == 1);
+    GGML_UNUSED(bs);
+    GGML_UNUSED(bx);
+    GGML_UNUSED(by);
+}
+
+//
 // ========================================= bf16_r4
 //
 namespace {
@@ -6450,6 +6691,47 @@ void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t
     GGML_UNUSED(by);
 }
 
+void quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
+    iqk_quantize_row_q8_KV(x, vy, k);
+}
+
+void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) {
+    quantize_row_q8_KV(x, y, k);
+}
+
+size_t quantize_q8_KV(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+    (void)imatrix;
+    auto row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+    auto q = (char *)dst;
+    for (int row = 0; row < nrows; ++row) {
+        quantize_row_q8_KV(src, q, n_per_row);
+        src += n_per_row;
+        q += row_size;
+    }
+    return row_size*nrows;
+}
+
+void dequantize_row_q8_KV(const void * x, float * y, int64_t k) {
+    auto dptr = (const float *)x;
+    float d = dptr[0];
+    auto q8 = (const int8_t *)(dptr + 2);
+    for (int j = 0; j < k; ++j) y[j] = d * q8[j];
+}
+
+void vec_dot_q8_KV_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
+        return;
+    }
+#endif
+    GGML_ASSERT(n%QK4_NL == 0);
+    GGML_ASSERT(nrc == 1);
+    GGML_UNUSED(bs);
+    GGML_UNUSED(bx);
+    GGML_UNUSED(by);
+}
+
+
 //================================================
 
 namespace {
@@ -6472,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) {
         { GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} },
 #endif
 #ifdef HAVE_FANCY_SIMD
-        { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
-        { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
+        { GGML_TYPE_Q8_0_R8,  {modify_q8_0_r8,  8} },
+        { GGML_TYPE_Q8_K_R8,  {modify_q8_k_r8,  8} },
+        { GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} },
 #endif
     };
     auto it = k_mod_map.find(tensor->type);
@@ -6532,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
        { GGML_TYPE_Q6_0,  { GGML_TYPE_Q6_0_R4,  4,  (Repack::repack_func)repack_q6_0}  },
        { GGML_TYPE_Q8_0,  { GGML_TYPE_Q8_0_R8,  8,  (Repack::repack_func)repack_q8_0}  },
        { GGML_TYPE_Q8_K,  { GGML_TYPE_Q8_K_R8,  8,  (Repack::repack_func)repack_q8_k}  },
+       { GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8,  (Repack::repack_func)repack_q8_KV} },
 #ifdef __AVX512BF16__
        { GGML_TYPE_BF16,  { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}},
        { GGML_TYPE_F16,   { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_half>}  },
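
Note: the two row layouts introduced by this commit can be summarized with plain scalar readers. The sketch below is derived directly from dequantize_row_q8_KV and dequantize_row_q8_KV_r8 in the diff above; the function names are illustrative only, not part of the library API.

    #include <cstdint>

    // Q8_KV: one row of n values = [ float d | int32 row_sum | int8 q[n] ],
    // so a row occupies 8 + n bytes. d = max|x|/127, q = round(x/d).
    void read_q8_kv(const void * vx, float * y, int n) {
        auto dptr = (const float *)vx;
        auto q8   = (const int8_t *)(dptr + 2);
        for (int j = 0; j < n; ++j) y[j] = dptr[0] * q8[j];
    }

    // Q8_KV_R8: 8 rows packed together = [ float d[8] | interleaved int8 quants ].
    // Element 16*ib + 4*l + i of row k lives at quants[128*ib + 32*l + 4*k + i],
    // i.e. the 8 rows are interleaved in groups of 4 bytes.
    void read_q8_kv_r8(const void * vx, float * y, int n_per_row) {
        auto dptr = (const float *)vx;
        auto q8   = (const int8_t *)(dptr + 8);
        for (int ib = 0; ib < n_per_row/16; ++ib)
            for (int l = 0; l < 4; ++l)
                for (int k = 0; k < 8; ++k)
                    for (int i = 0; i < 4; ++i)
                        y[k*n_per_row + 16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i];
    }

The 4-byte interleave is what lets the AVX2 path of repack_q8_KV produce the layout with four 32-bit unpack/permute steps per 16-column slice.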
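A second observation, hedged since the commit itself does not spell out the rationale: under HAVE_FANCY_SIMD, modify_q8_KV_r8 and the online branch of repack_q8_KV add 127 to every quant, presumably so the AVX512-VNNI unsigned-by-signed dot product (vpdpbusd) can consume the repacked rows directly, the same treatment the table gives Q8_0_R8 and Q8_K_R8. This would also explain why Q8_KV stores the row sum in its header: since (q + 127) . y = q . y + 127 * sum(y), the bias term is removed with one multiply. A minimal scalar sketch of the corrected dot product, with hypothetical names:

    #include <cstdint>

    // Dot product of one offset (unsigned) repacked row with one signed Q8_KV
    // row. uq = offset quants; dx = scale of the repacked row; y/dy/ysum come
    // from the Q8_KV header [ d | sum | q8... ] of the activation row.
    float q8_kv_dot_offset(const uint8_t * uq, float dx,
                           const int8_t * y, float dy, int32_t ysum, int n) {
        int32_t acc = 0;
        for (int j = 0; j < n; ++j) acc += int32_t(uq[j]) * y[j]; // u8 x s8, as vpdpbusd would
        // (q + 127) . y = q . y + 127 * sum(y)  =>  subtract the bias term
        return dx * dy * float(acc - 127*ysum);
    }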