From 3d68f364f15778dc326f5024f2e5af1ad6dfddef Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 13 Nov 2023 16:55:52 +0200
Subject: ggml : sync (im2col, GPU conv, 32-bit arm compat) (#4060)

ggml-ci
---
 ggml-quants.c | 241 ++++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 168 insertions(+), 73 deletions(-)

diff --git a/ggml-quants.c b/ggml-quants.c
index 740be6dc..a48eda73 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -14,26 +14,6 @@
 //
 #include <arm_neon.h>
 
-#if !defined(__aarch64__)
-inline static int32_t vaddvq_s16(int16x8_t v) {
-    return
-        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
-        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
-        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
-        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
-    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
-    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
-    return vcombine_s16(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
-    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-#endif
-
 #else
 
 #ifdef __wasm_simd128__
@@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #else
-#if !defined(__riscv) && !defined(__s390__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
 #include <immintrin.h>
 #endif
 #endif
 #endif
 #endif
 #endif
+#endif
 
 #ifdef __riscv_v_intrinsic
 #include <riscv_vector.h>
@@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 
 #undef MIN
 #undef MAX
+
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
@@ -283,9 +266,31 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
 #if defined(__ARM_NEON)
-
 #if !defined(__aarch64__)
 
+// 64-bit compatibility
+
+// vaddvq_s16
+// vpaddq_s16
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
 inline static int32_t vaddvq_s32(int32x4_t v) {
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
 }
@@ -311,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
     return res;
 }
 
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t int8x16x2_t
+#define ggml_int8x16x4_t int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2 vld1q_u8_x2
+#define ggml_vld1q_u8_x4 vld1q_u8_x4
+#define ggml_vld1q_s8_x2 vld1q_s8_x2
+#define ggml_vld1q_s8_x4 vld1q_s8_x4
+
 #endif
 
 #endif
@@ -3557,7 +3652,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t vzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x2_t q2bytes;
+    ggml_int8x16x2_t q2bytes;
     uint8_t aux[16];
 
     float sum = 0;
@@ -3576,8 +3671,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         vst1q_u8(aux, scales);
 
         const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
-        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
-        const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
         const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
                                        vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
         const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@@ -3605,7 +3700,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
-        q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
         MULTIPLY_ACCUM_WITH_SCALE((index));
@@ -3613,9 +3708,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
+            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
 
-            int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
             q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
             MULTIPLY_ACCUM_WITH_SCALE(0);
@@ -3949,7 +4044,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t vzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q2bytes;
+    ggml_int8x16x4_t q2bytes;
 
     uint32_t aux32[2];
     const uint8_t * scales = (const uint8_t *)aux32;
@@ -3974,7 +4069,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         const uint8x16_t q2bits = vld1q_u8(q2);
 
-        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
@@ -4238,7 +4333,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m3 = vshlq_n_u8(m0, 3);
     const int8_t m32 = 32;
 
-    int8x16x4_t q3bytes;
+    ggml_int8x16x4_t q3bytes;
 
     float sum = 0;
 
@@ -4250,9 +4345,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict qh = x[i].hmask;
         const int8_t * restrict q8 = y[i].qs;
 
-        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
-        uint8x16x4_t q3h;
+        ggml_uint8x16x4_t q3h;
 
         int32_t isum = 0;
 
@@ -4268,9 +4363,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
-            const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
-            const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
+            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
             q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
@@ -4772,7 +4867,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m3b = vdupq_n_u8(0x3);
     const uint8x16_t mh = vdupq_n_u8(4);
 
-    int8x16x4_t q3bytes;
+    ggml_int8x16x4_t q3bytes;
 
    uint16_t aux16[2];
    int8_t * scales = (int8_t *)aux16;
@@ -4781,11 +4876,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
    for (int i = 0; i < nb; ++i) {
 
-        uint8x16x4_t q3h;
+        ggml_uint8x16x4_t q3h;
 
        const uint8x8_t hbits = vld1_u8(x[i].hmask);
        const uint8x16_t q3bits = vld1q_u8(x[i].qs);
-        const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
 
        const uint16_t a = *(const uint16_t *)x[i].scales;
        aux16[0] = a & 0x0f0f;
@@ -5134,8 +5229,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x2_t q4bytes;
-    int8x16x2_t q8bytes;
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x2_t q8bytes;
 
     float sumf = 0;
 
@@ -5170,17 +5265,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
+            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
 
 #ifdef __ARM_FEATURE_DOTPROD
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
 
            const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
@@ -5188,7 +5283,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
 
 #else
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
             const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5197,7 +5292,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                            vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
             sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
 
-            q8bytes = vld1q_s8_x2(q8); q8 += 32;
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
             const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5512,8 +5607,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     float sumf = 0;
 
-    int8x16x2_t q4bytes;
-    int8x16x4_t q8bytes;
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x4_t q8bytes;
 
     float sum_mins = 0.f;
 
@@ -5534,10 +5629,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         const float d = y[i].d * (float)x[i].d[0];
 
-        const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+        const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
 
 #ifdef __ARM_FEATURE_DOTPROD
-        q8bytes = vld1q_s8_x4(q8);
+        q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
 
@@ -5551,7 +5646,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
 #else
-        q8bytes = vld1q_s8_x4(q8);
+        q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -5785,7 +5880,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q5bytes;
+    ggml_int8x16x4_t q5bytes;
 
     float sumf = 0;
 
@@ -5815,16 +5910,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict qh = x[i].qh;
         const int8_t * restrict q8 = y[i].qs;
 
-        uint8x16x2_t qhbits = vld1q_u8_x2(qh);
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
 
-        uint8x16x4_t q5h;
+        ggml_uint8x16x4_t q5h;
 
         int32_t sumi = 0;
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
-            const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6218,8 +6313,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int32x4_t mzero = vdupq_n_s32(0);
 #endif
 
-    int8x16x4_t q5bytes;
-    uint8x16x4_t q5h;
+    ggml_int8x16x4_t q5bytes;
+    ggml_uint8x16x4_t q5h;
 
     float sumf = 0;
 
@@ -6234,8 +6329,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
         const uint8x8_t qhbits = vld1_u8(qh);
 
-        const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
-        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
+        const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
         q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
@@ -6511,8 +6606,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     const uint8x16_t mone = vdupq_n_u8(3);
 
-    int8x16x4_t q6bytes;
-    uint8x16x4_t q6h;
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
 
     for (int i = 0; i < nb; ++i) {
 
@@ -6524,9 +6619,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         const int8_t * restrict scale = x[i].scales;
 
-        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
         const int8x16_t scales = vld1q_s8(scale);
-        const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+        const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
 
         const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
                                                    vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6538,9 +6633,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         for (int j = 0; j < QK_K/128; ++j) {
 
-            uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
-            uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
-            int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
             q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@@ -6583,7 +6678,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             scale += 2;
 #endif
 
-            q8bytes = vld1q_s8_x4(q8); q8 += 64;
+            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
             shifted = vshrq_n_u8(qhbits.val[0], 4);
             q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
@@ -6987,8 +7082,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     const uint8x16_t mone = vdupq_n_u8(3);
 
-    int8x16x4_t q6bytes;
-    uint8x16x4_t q6h;
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
 
     for (int i = 0; i < nb; ++i) {
 
@@ -7002,9 +7097,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         int32_t isum = 0;
 
-        uint8x16_t qhbits = vld1q_u8(qh);
-        uint8x16x2_t q6bits = vld1q_u8_x2(q6);
-        int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+        uint8x16_t qhbits = vld1q_u8(qh);
+        ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
+        ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
 
         q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
         uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
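
Usage sketch (not part of the patch; assumes an ARM NEON compiler target and uses a hypothetical helper name, low_nibbles): the ggml_vld1q_* wrappers introduced above resolve to the native multi-register loads on AArch64 and to per-register fallback loads on 32-bit ARM, so the quantized dot-product kernels can use a single spelling on both targets.

#include <arm_neon.h>
#include <stdint.h>

#if !defined(__aarch64__)
// 32-bit ARM: armv7 NEON has no vld1q_u8_x2, so emulate it with two 16-byte loads.
typedef struct ggml_uint8x16x2_t {
    uint8x16_t val[2];
} ggml_uint8x16x2_t;

inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
    ggml_uint8x16x2_t res;
    res.val[0] = vld1q_u8(ptr + 0);
    res.val[1] = vld1q_u8(ptr + 16);
    return res;
}
#else
// AArch64: the native multi-register load exists, so the wrapper is a plain alias.
#define ggml_uint8x16x2_t uint8x16x2_t
#define ggml_vld1q_u8_x2  vld1q_u8_x2
#endif

// Hypothetical helper: extract the low nibbles of 32 packed bytes (as a Q4-style
// kernel would), loading them through the wrapper so the code builds unchanged
// on both 32-bit and 64-bit ARM.
static void low_nibbles(const uint8_t * src, uint8_t * dst) {
    const uint8x16_t m4b = vdupq_n_u8(0x0F);
    const ggml_uint8x16x2_t q4 = ggml_vld1q_u8_x2(src);
    vst1q_u8(dst + 0,  vandq_u8(q4.val[0], m4b));
    vst1q_u8(dst + 16, vandq_u8(q4.val[1], m4b));
}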