summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRadosław Gryta <radek.gryta@gmail.com>2024-02-25 19:43:00 +0100
committerGitHub <noreply@github.com>2024-02-25 20:43:00 +0200
commitabbabc5e51d0d4656b438aec10b7fae9479ef37d (patch)
tree274a9c42f75ceef14e7541b810e2605027786290
parentf1a98c52546d009f742bdec2154c2a314ea950a6 (diff)
ggml-quants : provide ggml_vqtbl1q_u8 for 64bit compatibility (#5711)
* [ggml-quants] Provide ggml_vqtbl1q_u8 for 64bit compatibility vqtbl1q_u8 is not part of arm v7 neon library * [android-example] Remove abi filter after arm v7a fix * [github-workflows] Do not skip Android armeabi-v7a build
-rw-r--r--.github/workflows/build.yml3
-rw-r--r--examples/llama.android/app/build.gradle.kts8
-rw-r--r--ggml-quants.c33
3 files changed, 32 insertions, 12 deletions
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 03d76d45..66ad8593 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -669,8 +669,7 @@ jobs:
run: |
cd examples/llama.android
- # Skip armeabi-v7a for now (https://github.com/llvm/llvm-project/issues/65820).
- ./gradlew build --no-daemon -Pskip-armeabi-v7a
+ ./gradlew build --no-daemon
# freeBSD-latest:
# runs-on: macos-12
diff --git a/examples/llama.android/app/build.gradle.kts b/examples/llama.android/app/build.gradle.kts
index aadbe22c..d42140ef 100644
--- a/examples/llama.android/app/build.gradle.kts
+++ b/examples/llama.android/app/build.gradle.kts
@@ -21,12 +21,8 @@ android {
useSupportLibrary = true
}
ndk {
- // Workaround for https://github.com/llvm/llvm-project/issues/65820
- // affecting armeabi-v7a. Skip armeabi-v7a when invoked with
- // -Pskip-armeabi-v7a (e.g., ./gradlew build -Pskip-armeabi-v7a).
- if (project.hasProperty("skip-armeabi-v7a")) {
- abiFilters += listOf("arm64-v8a", "x86_64", "x86")
- }
+ // Add NDK properties if wanted, e.g.
+ // abiFilters += listOf("arm64-v8a")
}
externalNativeBuild {
cmake {
diff --git a/ggml-quants.c b/ggml-quants.c
index 5c5f2ce1..3d94d166 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -462,6 +462,30 @@ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
return res;
}
+// NOTE: not tested
+inline static int8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
+ int8x16_t res;
+
+ res[ 0] = a[b[ 0]];
+ res[ 1] = a[b[ 1]];
+ res[ 2] = a[b[ 2]];
+ res[ 3] = a[b[ 3]];
+ res[ 4] = a[b[ 4]];
+ res[ 5] = a[b[ 5]];
+ res[ 6] = a[b[ 6]];
+ res[ 7] = a[b[ 7]];
+ res[ 8] = a[b[ 8]];
+ res[ 9] = a[b[ 9]];
+ res[10] = a[b[10]];
+ res[11] = a[b[11]];
+ res[12] = a[b[12]];
+ res[13] = a[b[13]];
+ res[14] = a[b[14]];
+ res[15] = a[b[15]];
+
+ return res;
+}
+
#else
#define ggml_int16x8x2_t int16x8x2_t
@@ -476,6 +500,7 @@ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
#define ggml_vld1q_s8_x2 vld1q_s8_x2
#define ggml_vld1q_s8_x4 vld1q_s8_x4
#define ggml_vqtbl1q_s8 vqtbl1q_s8
+#define ggml_vqtbl1q_u8 vqtbl1q_u8
#endif
@@ -9488,8 +9513,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
qs += 16;
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
- vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
- vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+ vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+ vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
vs.val[0] = vceqq_u8(vs.val[0], mask2);
vs.val[1] = vceqq_u8(vs.val[1], mask2);
@@ -9497,8 +9522,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
- vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
- vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+ vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+ vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
vs.val[0] = vceqq_u8(vs.val[0], mask2);
vs.val[1] = vceqq_u8(vs.val[1], mask2);