summaryrefslogtreecommitdiff
path: root/ggml-quants.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-quants.c')
-rw-r--r--ggml-quants.c241
1 files changed, 168 insertions, 73 deletions
diff --git a/ggml-quants.c b/ggml-quants.c
index 740be6dc..a48eda73 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -14,26 +14,6 @@
//
#include <arm_neon.h>
-#if !defined(__aarch64__)
-inline static int32_t vaddvq_s16(int16x8_t v) {
- return
- (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
- (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
- (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
- (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
- int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
- int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
- return vcombine_s16(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-#endif
-
#else
#ifdef __wasm_simd128__
@@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
-#if !defined(__riscv) && !defined(__s390__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
#include <immintrin.h>
#endif
#endif
#endif
#endif
#endif
+#endif
#ifdef __riscv_v_intrinsic
#include <riscv_vector.h>
@@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
#undef MIN
#undef MAX
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
@@ -283,9 +266,31 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if defined(__ARM_NEON)
-
#if !defined(__aarch64__)
+// 64-bit compatibility
+
+// vaddvq_s16
+// vpaddq_s16
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+ return
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+ int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+ int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+ return vcombine_s16(a0, b0);
+}
+
inline static int32_t vaddvq_s32(int32x4_t v) {
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
@@ -311,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
return res;
}
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+ int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+ ggml_int16x8x2_t res;
+
+ res.val[0] = vld1q_s16(ptr + 0);
+ res.val[1] = vld1q_s16(ptr + 8);
+
+ return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+ uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+ ggml_uint8x16x2_t res;
+
+ res.val[0] = vld1q_u8(ptr + 0);
+ res.val[1] = vld1q_u8(ptr + 16);
+
+ return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+ uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+ ggml_uint8x16x4_t res;
+
+ res.val[0] = vld1q_u8(ptr + 0);
+ res.val[1] = vld1q_u8(ptr + 16);
+ res.val[2] = vld1q_u8(ptr + 32);
+ res.val[3] = vld1q_u8(ptr + 48);
+
+ return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+ int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+ ggml_int8x16x2_t res;
+
+ res.val[0] = vld1q_s8(ptr + 0);
+ res.val[1] = vld1q_s8(ptr + 16);
+
+ return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+ int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+ ggml_int8x16x4_t res;
+
+ res.val[0] = vld1q_s8(ptr + 0);
+ res.val[1] = vld1q_s8(ptr + 16);
+ res.val[2] = vld1q_s8(ptr + 32);
+ res.val[3] = vld1q_s8(ptr + 48);
+
+ return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t int8x16x2_t
+#define ggml_int8x16x4_t int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2 vld1q_u8_x2
+#define ggml_vld1q_u8_x4 vld1q_u8_x4
+#define ggml_vld1q_s8_x2 vld1q_s8_x2
+#define ggml_vld1q_s8_x4 vld1q_s8_x4
+
#endif
#endif
@@ -3557,7 +3652,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t vzero = vdupq_n_s32(0);
#endif
- int8x16x2_t q2bytes;
+ ggml_int8x16x2_t q2bytes;
uint8_t aux[16];
float sum = 0;
@@ -3576,8 +3671,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
vst1q_u8(aux, scales);
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
- const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
- const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+ const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+ const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@@ -3605,7 +3700,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#endif
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
- q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
MULTIPLY_ACCUM_WITH_SCALE((index));
@@ -3613,9 +3708,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) {
- const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+ const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
- int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
MULTIPLY_ACCUM_WITH_SCALE(0);
@@ -3949,7 +4044,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t vzero = vdupq_n_s32(0);
#endif
- int8x16x4_t q2bytes;
+ ggml_int8x16x4_t q2bytes;
uint32_t aux32[2];
const uint8_t * scales = (const uint8_t *)aux32;
@@ -3974,7 +4069,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t q2bits = vld1q_u8(q2);
- const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
@@ -4238,7 +4333,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t m3 = vshlq_n_u8(m0, 3);
const int8_t m32 = 32;
- int8x16x4_t q3bytes;
+ ggml_int8x16x4_t q3bytes;
float sum = 0;
@@ -4250,9 +4345,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8_t * restrict qh = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
- uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
- uint8x16x4_t q3h;
+ ggml_uint8x16x4_t q3h;
int32_t isum = 0;
@@ -4268,9 +4363,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) {
- const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
- const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
- const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
+ const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+ const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+ const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
@@ -4772,7 +4867,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t m3b = vdupq_n_u8(0x3);
const uint8x16_t mh = vdupq_n_u8(4);
- int8x16x4_t q3bytes;
+ ggml_int8x16x4_t q3bytes;
uint16_t aux16[2];
int8_t * scales = (int8_t *)aux16;
@@ -4781,11 +4876,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
for (int i = 0; i < nb; ++i) {
- uint8x16x4_t q3h;
+ ggml_uint8x16x4_t q3h;
const uint8x8_t hbits = vld1_u8(x[i].hmask);
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
- const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f;
@@ -5134,8 +5229,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0);
#endif
- int8x16x2_t q4bytes;
- int8x16x2_t q8bytes;
+ ggml_int8x16x2_t q4bytes;
+ ggml_int8x16x2_t q8bytes;
float sumf = 0;
@@ -5170,17 +5265,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/64; ++j) {
- const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+ const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
#ifdef __ARM_FEATURE_DOTPROD
- q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
- q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
@@ -5188,7 +5283,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
#else
- q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5197,7 +5292,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
- q8bytes = vld1q_s8_x2(q8); q8 += 32;
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5512,8 +5607,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
float sumf = 0;
- int8x16x2_t q4bytes;
- int8x16x4_t q8bytes;
+ ggml_int8x16x2_t q4bytes;
+ ggml_int8x16x4_t q8bytes;
float sum_mins = 0.f;
@@ -5534,10 +5629,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const float d = y[i].d * (float)x[i].d[0];
- const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+ const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
#ifdef __ARM_FEATURE_DOTPROD
- q8bytes = vld1q_s8_x4(q8);
+ q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@@ -5551,7 +5646,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
#else
- q8bytes = vld1q_s8_x4(q8);
+ q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5785,7 +5880,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0);
#endif
- int8x16x4_t q5bytes;
+ ggml_int8x16x4_t q5bytes;
float sumf = 0;
@@ -5815,16 +5910,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
- uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
- uint8x16x4_t q5h;
+ ggml_uint8x16x4_t q5h;
int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) {
- const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
- const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+ const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6218,8 +6313,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0);
#endif
- int8x16x4_t q5bytes;
- uint8x16x4_t q5h;
+ ggml_int8x16x4_t q5bytes;
+ ggml_uint8x16x4_t q5h;
float sumf = 0;
@@ -6234,8 +6329,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x8_t qhbits = vld1_u8(qh);
- const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
- const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+ const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
@@ -6511,8 +6606,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t mone = vdupq_n_u8(3);
- int8x16x4_t q6bytes;
- uint8x16x4_t q6h;
+ ggml_int8x16x4_t q6bytes;
+ ggml_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) {
@@ -6524,9 +6619,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const int8_t * restrict scale = x[i].scales;
- const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+ const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
const int8x16_t scales = vld1q_s8(scale);
- const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+ const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6538,9 +6633,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) {
- uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
- uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
- int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+ ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+ ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6583,7 +6678,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
scale += 2;
#endif
- q8bytes = vld1q_s8_x4(q8); q8 += 64;
+ q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
shifted = vshrq_n_u8(qhbits.val[0], 4);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
@@ -6987,8 +7082,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t mone = vdupq_n_u8(3);
- int8x16x4_t q6bytes;
- uint8x16x4_t q6h;
+ ggml_int8x16x4_t q6bytes;
+ ggml_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) {
@@ -7002,9 +7097,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
int32_t isum = 0;
- uint8x16_t qhbits = vld1q_u8(qh);
- uint8x16x2_t q6bits = vld1q_u8_x2(q6);
- int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+ uint8x16_t qhbits = vld1q_u8(qh);
+ ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
+ ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
uint8x16_t shifted = vshrq_n_u8(qhbits, 2);