Diffstat (limited to 'ggml/src/iqk/iqk_gemm_1bit.cpp')
-rw-r--r--  ggml/src/iqk/iqk_gemm_1bit.cpp  159
1 file changed, 158 insertions(+), 1 deletion(-)
diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp
index 05196c1d..770fbf2c 100644
--- a/ggml/src/iqk/iqk_gemm_1bit.cpp
+++ b/ggml/src/iqk/iqk_gemm_1bit.cpp
@@ -1607,6 +1607,162 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
}
#endif
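+// Helper: takes one QK_K block of a single row as 8 x 32 int8 values in qx,
+// multiplies them by the per-16 scales in `scales`, and writes the result as row k
+// of a block_q8_k_r8 (8 rows interleaved in 4-byte groups). If the scaled values
+// exceed d0 in magnitude they are requantized so that |q| <= d0, and the factor
+// dnew returned here must be folded into the row scale by the caller.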
+inline float convert_to_q8_k_r8(int k, int d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
+ auto max_i16 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l));
+ max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h));
+ }
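+    // horizontal reduction of the per-lane abs-max down to a single scalar in max4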
+ auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+ auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+ auto max4 = _mm_cvtepi32_ps(imax4);
+ max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+ max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+ bool needs_scaling = true;
+ float dnew = _mm_cvtss_f32(max4) / d0;
+ if (dnew < 1.f) {
+ dnew = 1.f; needs_scaling = false;
+ }
+ auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
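+    // second pass: recompute the scaled values and, if needed, requantize with 1/dnew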
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+ auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+ q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
+ q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
+ if (needs_scaling) {
+ auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+ auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+ auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+ auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+ i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+ i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+ i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+ i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+ i0 = _mm256_packs_epi32(i0, i1);
+ i2 = _mm256_packs_epi32(i2, i3);
+ i0 = _mm256_packs_epi16(i0, i2);
+ i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+ _mm256_storeu_si256((__m256i *)block, i0);
+ } else {
+            // byte order after _mm256_packs_epi16: 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
+ auto i0 = _mm256_packs_epi16(q16_l, q16_h);
+ auto i0_l = _mm256_castsi256_si128(i0);
+ auto i0_h = _mm256_extracti128_si256(i0, 1);
+ _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+ _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+ }
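+        // scatter this 32-byte sub-block into the interleaved q8_k_r8 layout:
+        // within each sub-block the 8 rows alternate in 4-byte groups, so row k
+        // owns every 8th uint32_t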
+ auto qs = (uint32_t *)q8_k + 64*ib32;
+ for (int l = 0; l < 8; ++l) {
+ qs[8*l + k] = block[l];
+ }
+ }
+ return dnew;
+}
+
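+// Repack 8 rows of IQ1_S into block_q8_k_r8 (8 interleaved rows per QK_K block).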
+void iqk_convert_iq1_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq1_s * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ int16_t ls[16];
+
+ uint32_t block[8];
+
+ __m256i qx[8];
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+ auto qs = x8[k][i].qs;
+ auto qh = x8[k][i].qh;
+ __m256i value;
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ ls[2*ib32 + 0] = (2*((qh[ib32] >> 12) & 7) + 1);
+ ls[2*ib32 + 1] = ls[2*ib32 + 0];
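+                    // each group of 8 weights is one iq1s_grid entry, indexed by 8 bits
+                    // from qs plus 3 high bits taken from qh[ib32]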
+ value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib32] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib32] << 2) & 0x700)],
+ iq1s_grid[qs[1] | ((qh[ib32] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib32] << 8) & 0x700)]);
+ value = _mm256_slli_epi16(_mm256_add_epi8(value, _mm256_set1_epi8(1)), 3);
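+                    // value is now 8*(g+1) with g in {-1, 0, 1}; adding -9 or -7 below
+                    // yields 8g-1 or 8g+1 depending on the sign bit (0x8000) of qh[ib32]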
+ int8_t delta = qh[ib32] & 0x8000 ? -9 : -7;
+ value = _mm256_add_epi8(value, _mm256_set1_epi8(delta));
+ qx[ib32] = value;
+ qs += 4;
+ }
+ float dnew = convert_to_q8_k_r8(k, 126, qx, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
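+// Repack 8 rows of IQ1_M into block_q8_k_r8. Unlike IQ1_S, the row scale is an fp16
+// reconstructed from the top 4 bits of the four scale words, and the per-group +/-1
+// shift around 8g is selected by bits 0x08 and 0x80 of the qh bytes instead of a
+// single sign bit.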
+void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
+
+ int nb = n/QK_K;
+
+ const block_iq1_m * x8[8];
+
+ block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+ int16_t ls[16];
+
+ uint32_t block[8];
+
+ __m256i qx[8];
+
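+    // mask selecting the qh bits (0x08 and 0x80 of each qh byte, with qh[0] | qh[1]<<16
+    // broadcast) that choose the per-group delta below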
+ auto mask = _mm256_setr_epi32(0x00000008, 0x00000008, 0x00000080, 0x00000080, 0x00080000, 0x00080000, 0x00800000, 0x00800000);
+
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx);
+ for (int i = 0; i < nb; ++i) {
+ for (int k = 0; k < 8; ++k) {
+ const uint16_t * sc = (const uint16_t *)x8[k][i].scales;
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ float d = 0.125f * GGML_FP16_TO_FP32(scale.f16);
+ auto qs = x8[k][i].qs;
+ auto qh = x8[k][i].qh;
+ __m256i value;
+ for (int ib32 = 0; ib32 < 8; ++ib32) {
+ ls[2*ib32 + 0] = (2*((sc[ib32/2] >> (6*(ib32%2)+0)) & 0x7) + 1);
+ ls[2*ib32 + 1] = (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1);
+ value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | ((qh[1] << 8) & 0x700)],
+ iq1s_grid[qs[1] | ((qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | ((qh[0] << 8) & 0x700)]);
+ value = _mm256_slli_epi16(_mm256_add_epi8(value, _mm256_set1_epi8(1)), 3);
+
+ auto delta_mask = _mm256_cmpeq_epi32(_mm256_and_si256(_mm256_set1_epi32(qh[0] | qh[1] << 16), mask), mask);
+ auto delta = _mm256_add_epi8(_mm256_set1_epi8(7), _mm256_and_si256(delta_mask, _mm256_set1_epi8(2)));
+ qx[ib32] = _mm256_sub_epi8(value, delta);
+
+ //int64_t delta1 = qh[0] & 0x08 ? 0x0909090909090909 : 0x0707070707070707;
+ //int64_t delta2 = qh[0] & 0x80 ? 0x0909090909090909 : 0x0707070707070707;
+ //int64_t delta3 = qh[1] & 0x08 ? 0x0909090909090909 : 0x0707070707070707;
+ //int64_t delta4 = qh[1] & 0x80 ? 0x0909090909090909 : 0x0707070707070707;
+ //value = _mm256_sub_epi8(value, _mm256_set_epi64x(delta4, delta3, delta2, delta1));
+ //qx[ib32] = value;
+ qs += 4;
+ qh += 2;
+ }
+ float dnew = convert_to_q8_k_r8(k, 126, qx, ls, block, y[i].qs);
+ y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+ }
+ }
+ y += nb;
+ }
+}
+
void iqk_convert_iq1_s_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
@@ -1722,7 +1878,8 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
bool iqk_convert_1bit_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
switch (ggml_type(type)) {
- case GGML_TYPE_IQ1_S: iqk_convert_iq1_s_q8_0_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ1_S: iqk_convert_iq1_s_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+ case GGML_TYPE_IQ1_M: iqk_convert_iq1_m_q8_k_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;