summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-06-01 15:24:33 +0300
committerGitHub <noreply@github.com>2025-06-01 15:24:33 +0300
commit7a8abe29f745cff95896095bf19cf247bdf2c661 (patch)
tree2a7a9622e7a8590b2d11fb28192665b26704fbf0
parent3df1a3a44d69490d074f22aa04ca542f2e72996f (diff)
Minor (~2%) iq2_ks TG performance improvement on CUDA (#468)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu17
1 files changed, 10 insertions, 7 deletions
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 31cc1ecd..ae11ae14 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -804,11 +804,14 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
const uint8_t * a8 = (const uint8_t *)&aux32;
int v1, v2;
- int8_t s8[4];
- s8[0] = ((bq2->scales[2*(i4/4)+0] & 0xf) | ((extra >> 4) & 0x10)) - 16;
- s8[1] = ((bq2->scales[2*(i4/4)+0] >> 4) | ((extra >> 5) & 0x10)) - 16;
- s8[2] = ((bq2->scales[2*(i4/4)+1] & 0xf) | ((extra >> 6) & 0x10)) - 16;
- s8[3] = ((bq2->scales[2*(i4/4)+1] >> 4) | ((extra >> 7) & 0x10)) - 16;
+ int32_t scales32;
+ const uint16_t * scales16 = (const uint16_t *)bq2->scales;
+ scales32 = __vsub4((scales16[i4/4] | (scales16[i4/4] << 12)) & 0x0f0f0f0f, 0x10101010);
+ int8_t * s8 = (int8_t *)&scales32;
+ s8[0] += ((extra >> 4) & 0x10);
+ s8[1] += ((extra >> 6) & 0x10);
+ s8[2] += ((extra >> 5) & 0x10);
+ s8[3] += ((extra >> 7) & 0x10);
aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8);
v1 = int_from_table_4(a8 + 0, values);
@@ -818,12 +821,12 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
aux32[0] = ((val1 >> 2) & 0x03030303); aux32[1] = ((val2 >> 2) & 0x03030303); values = all_values + ((extra & 0x02) << 7);
v1 = int_from_table_4(a8 + 0, values);
v2 = int_from_table_4(a8 + 4, values);
- int sumi2 = ggml_cuda_dp4a(v2, q8_2[1], ggml_cuda_dp4a(v1, q8_2[0], 0)) * s8[1];
+ int sumi2 = ggml_cuda_dp4a(v2, q8_2[1], ggml_cuda_dp4a(v1, q8_2[0], 0)) * s8[2];
aux32[0] = ((val1 >> 4) & 0x03030303); aux32[1] = ((val2 >> 4) & 0x03030303); values = all_values + ((extra & 0x04) << 6);
v1 = int_from_table_4(a8 + 0, values);
v2 = int_from_table_4(a8 + 4, values);
- int sumi3 = ggml_cuda_dp4a(v2, q8_3[1], ggml_cuda_dp4a(v1, q8_3[0], 0)) * s8[2];
+ int sumi3 = ggml_cuda_dp4a(v2, q8_3[1], ggml_cuda_dp4a(v1, q8_3[0], 0)) * s8[1];
aux32[0] = ((val1 >> 6) & 0x03030303); aux32[1] = ((val2 >> 6) & 0x03030303); values = all_values + ((extra & 0x08) << 5);
v1 = int_from_table_4(a8 + 0, values);